 import torch
 import torch.nn.functional as F
 
-from executorch.examples.models.lfm2.short_conv import ShortConvBlock
 from executorch.examples.models.llama.attention import (
     Attention,
     ATTENTION_REGISTRY,
@@ -87,10 +86,15 @@ def __init__(self, args: ModelArgs, attention: Attention):
         self.dim = args.dim
         self.head_dim = args.head_dim
         self.attention = attention
+
+        assert (
+            args.hidden_dim is not None
+        ), "`hidden_dim` must be set in ModelArgs to construct a TransformerBlock."
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
             self.feed_forward = FeedForward(dim=args.dim, hidden_dim=args.hidden_dim)
+
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
@@ -245,6 +249,11 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
     for layer_id in range(model_args.n_layers):
         # hybrid models define layer_types
         if model_args.layer_types and model_args.layer_types[layer_id] == "conv":
+            from executorch.examples.models.lfm2.short_conv import ShortConvBlock
+
+            assert (
+                model_args.hidden_dim is not None
+            ), "`hidden_dim` must be set in ModelArgs to construct a TransformerBlock."
             layers.append(
                 ShortConvBlock(
                     dim=model_args.dim,
@@ -253,7 +262,9 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
                 )
             )
         else:
-            attention = cls(model_args, layer_id, rope, **model_args.attention_kwargs)
+            attention = cls(
+                model_args, layer_id, rope, **model_args.attention_kwargs
+            )  # pyre-ignore[45]
             transformer_block = TransformerBlock(model_args, attention)
             layers.append(transformer_block)
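
Note on the changes above: the `ShortConvBlock` import is moved inside the `"conv"` branch of `construct_transformer`, so the lfm2 module is only imported when a hybrid model actually requests a conv layer, and the new `assert ... is not None` checks narrow the (presumably Optional) `hidden_dim` field of `ModelArgs` before it reaches `FeedForward` / `ShortConvBlock`. A minimal, self-contained sketch of that narrowing pattern, using illustrative stand-in names (`Args`, `build_ffn`) rather than the actual ExecuTorch classes:

from dataclasses import dataclass
from typing import Optional


@dataclass
class Args:
    dim: int
    hidden_dim: Optional[int] = None  # some configs leave this unset


def build_ffn(args: Args) -> None:
    # The assert narrows Optional[int] -> int for static checkers (pyre/mypy);
    # without it, using args.hidden_dim where an int is expected is flagged.
    assert (
        args.hidden_dim is not None
    ), "`hidden_dim` must be set to construct the feed-forward block."
    hidden: int = args.hidden_dim  # accepted after narrowing
    print(f"FeedForward(dim={args.dim}, hidden_dim={hidden})")


build_ffn(Args(dim=256, hidden_dim=1024))

After the assert, type checkers treat the field as a plain int, which is the effect the asserts added in `TransformerBlock.__init__` and `construct_transformer` are after.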