 import torch
 import torch.nn.functional as F

-from executorch.examples.models.lfm2.short_conv import ShortConvBlock
 from executorch.examples.models.llama.attention import (
     Attention,
     ATTENTION_REGISTRY,
@@ -87,10 +86,15 @@ def __init__(self, args: ModelArgs, attention: Attention):
         self.dim = args.dim
         self.head_dim = args.head_dim
         self.attention = attention
+
+        assert (
+            args.hidden_dim is not None
+        ), "`hidden_dim` must be set in ModelArgs to construct a TransformerBlock."
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
             self.feed_forward = FeedForward(dim=args.dim, hidden_dim=args.hidden_dim)
+
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

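Aside from failing fast with a readable message when `hidden_dim` is missing, the added assert also narrows `args.hidden_dim` from `Optional[int]` to `int` for type checkers such as pyre or mypy, which is what lets the `FeedForward(...)` call below it type-check. A minimal, self-contained sketch of that pattern follows; the helper name and signature are hypothetical, not part of this commit:

from typing import Optional

def require_hidden_dim(hidden_dim: Optional[int]) -> int:
    # Hypothetical helper illustrating the narrowing: after the assert, type
    # checkers treat `hidden_dim` as int rather than Optional[int], and a
    # missing value raises immediately with a clear message.
    assert (
        hidden_dim is not None
    ), "`hidden_dim` must be set in ModelArgs to construct a TransformerBlock."
    return hidden_dim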
@@ -245,6 +249,11 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
     for layer_id in range(model_args.n_layers):
         # hybrid models define layer_types
         if model_args.layer_types and model_args.layer_types[layer_id] == "conv":
+            from executorch.examples.models.lfm2.short_conv import ShortConvBlock
+
+            assert (
+                model_args.hidden_dim is not None
+            ), "`hidden_dim` must be set in ModelArgs to construct a TransformerBlock."
             layers.append(
                 ShortConvBlock(
                     dim=model_args.dim,
@@ -253,7 +262,9 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
                 )
             )
         else:
-            attention = cls(model_args, layer_id, rope, **model_args.attention_kwargs)
+            attention = cls(
+                model_args, layer_id, rope, **model_args.attention_kwargs
+            )  # pyre-ignore[45]
             transformer_block = TransformerBlock(model_args, attention)
             layers.append(transformer_block)

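For context, a usage sketch of the hybrid construction path this commit touches. The `construct_transformer` name, the `ModelArgs` field names, and the `"conv"` dispatch come from the diff; the import paths and the concrete values below are assumptions for illustration only:

# Assumed module paths and illustrative values; only the field names and the
# "conv" dispatch are taken from the diff above.
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.llama_transformer import construct_transformer

model_args = ModelArgs(
    dim=512,
    n_layers=4,
    hidden_dim=2048,  # required: both new asserts check this before it is used
    layer_types=["conv", "attention", "conv", "attention"],  # hybrid layout
)
# ShortConvBlock (and the lfm2 package) is now imported only when a "conv"
# layer is actually requested, so attention-only models avoid that dependency.
model = construct_transformer(model_args)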