 import torch.nn.functional as F
 
 from executorch.examples.models.llama.attention import (
+    Attention,
     ATTENTION_REGISTRY,
     ForwardOptions,
 )
@@ -83,26 +84,46 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
+    def __init__(self, args: ModelArgs, attention: Attention):
+        """
+        Transformer block with support for pre-norm and post-norm.
+        Args:
+            args (ModelArgs): model configuration parameters.
+            attention (Attention): attention object to use in the transformer
+                block. See `attention.py` for types of attention. Make sure
+                the attention type is registered in the ATTENTION_REGISTRY.
+        """
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.head_dim
-        if args.attention_type not in ATTENTION_REGISTRY:
-            raise ValueError(
-                f"Unknown attention type: {args.attention_type}. "
-                f"Available: {list(ATTENTION_REGISTRY.keys())}"
-            )
-        cls = ATTENTION_REGISTRY[args.attention_type]
-        self.attention = cls(args, layer_id, rope)
+        self.attention = attention
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
             self.feed_forward = FeedForward(args)
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
+    @classmethod
+    def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
+        """
+        Create a TransformerBlock with the legacy constructor.
+        Args:
+            layer_id (int): the index of the layer.
+            args (ModelArgs): model configuration parameters.
+            rope (Rope): the rope object to use for rotary embeddings.
+        """
+        if args.attention_type not in ATTENTION_REGISTRY:
+            raise ValueError(
+                f"Unknown attention type: {args.attention_type}. "
+                f"Available: {list(ATTENTION_REGISTRY.keys())}"
+            )
+        cls = ATTENTION_REGISTRY[args.attention_type]
+        attention = cls(args, layer_id, rope)
+        return TransformerBlock(args, attention)
+
     def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x: 1xN
         h, attn_options_update = self.attention.forward(
             self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
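
For reference, a minimal construction sketch under the new `TransformerBlock` API (using the names already in scope in this module; the ModelArgs defaults and the assumption that the configured attention type, e.g. "mha", is registered in ATTENTION_REGISTRY are illustrative, not taken from the diff):

    # Sketch only: ModelArgs defaults and the registered attention type are assumptions.
    args = ModelArgs()
    rope = Rope(args)

    # New path: build the attention module explicitly and inject it.
    attn_cls = ATTENTION_REGISTRY[args.attention_type]
    block = TransformerBlock(args, attn_cls(args, 0, rope))

    # Legacy path: keep the old (layer_id, args, rope) call via from_type.
    legacy_block = TransformerBlock.from_type(0, args, rope)
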
@@ -117,7 +138,15 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x:
 
 
 class Transformer(nn.Module):
-    def __init__(self, params: ModelArgs):
+    def __init__(self, params: ModelArgs, layers: nn.ModuleList, rope: Rope):
+        """
+        Transformer model.
+        Args:
+            params (ModelArgs): model configuration parameters.
+            layers (nn.ModuleList): list of transformer blocks - see the
+                `TransformerBlock` type above.
+            rope (Rope): the rope object to use for rotary embeddings.
+        """
         super().__init__()
         self.params = params
         self.vocab_size = params.vocab_size
@@ -130,10 +159,8 @@ def __init__(self, params: ModelArgs):
             if self.apply_embedding
             else None
         )
-        self.rope = Rope(params)
-        self.layers = torch.nn.ModuleList()
-        for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params, self.rope))
+        self.layers = layers
+        self.rope = rope
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = (
             nn.Linear(params.dim, params.vocab_size, bias=False)
@@ -212,3 +239,23 @@ def forward(
             return logits, attn_options_update
 
         return logits
+
+
+def construct_transformer(model_args: ModelArgs) -> Transformer:
+    """
+    Construct a Transformer model from the given model arguments.
+    """
+    rope = Rope(model_args)
+    if model_args.attention_type not in ATTENTION_REGISTRY:
+        raise ValueError(
+            f"Unknown attention type: {model_args.attention_type}. "
+            f"Available: {list(ATTENTION_REGISTRY.keys())}"
+        )
+    layers = torch.nn.ModuleList()
+    cls = ATTENTION_REGISTRY[model_args.attention_type]
+    for layer_id in range(model_args.n_layers):
+        attention = cls(model_args, layer_id, rope)
+        transformer_block = TransformerBlock(model_args, attention)
+        layers.append(transformer_block)
+
+    return Transformer(model_args, layers, rope)
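
A short usage sketch of the new factory, alongside the equivalent manual wiring that the injected `layers`/`rope` constructor makes possible (field values are illustrative; ModelArgs defaults are assumed to name a registered attention type):

    # Sketch only: the n_layers value and ModelArgs defaults are assumptions.
    model_args = ModelArgs(n_layers=4)
    model = construct_transformer(model_args)
    assert len(model.layers) == model_args.n_layers

    # Equivalent manual wiring, e.g. to swap in a custom registered attention per layer.
    rope = Rope(model_args)
    attn_cls = ATTENTION_REGISTRY[model_args.attention_type]
    layers = torch.nn.ModuleList(
        [TransformerBlock(model_args, attn_cls(model_args, i, rope))
         for i in range(model_args.n_layers)]
    )
    model_manual = Transformer(model_args, layers, rope)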