 import torch
 import torch.nn.functional as F

-from executorch.examples.models.llama.attention import (
-    ATTENTION_REGISTRY,
-    ForwardOptions,
-)
+from executorch.examples.models.llama.attention import Attention, ForwardOptions

 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.norm import RMSNorm
@@ -83,19 +80,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


 class TransformerBlock(nn.Module):
-    def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
+    def __init__(self, args: ModelArgs, attention: Attention):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.head_dim
-        if args.attention_type not in ATTENTION_REGISTRY:
-            raise ValueError(
-                f"Unknown attention type: {args.attention_type}. "
-                f"Available: {list(ATTENTION_REGISTRY.keys())}"
-            )
-        cls = ATTENTION_REGISTRY[args.attention_type]
-        self.attention = cls(args, layer_id, rope)
+        self.attention = attention
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
@@ -117,7 +108,7 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x:


 class Transformer(nn.Module):
-    def __init__(self, params: ModelArgs):
+    def __init__(self, params: ModelArgs, layers: nn.ModuleList, rope: Rope):
         super().__init__()
         self.params = params
         self.vocab_size = params.vocab_size
@@ -130,10 +121,8 @@ def __init__(self, params: ModelArgs):
             if self.apply_embedding
             else None
         )
-        self.rope = Rope(params)
-        self.layers = torch.nn.ModuleList()
-        for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params, self.rope))
+        self.layers = layers
+        self.rope = rope
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = (
             nn.Linear(params.dim, params.vocab_size, bias=False)
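# --- Usage sketch (not part of the diff above) ---
# With this refactor, TransformerBlock and Transformer receive their
# collaborators (attention module, layers, rope) via constructor injection
# instead of building them internally. Below is a minimal sketch of how a
# caller might now assemble the model. The `construct_transformer` helper
# name is an assumption for illustration, the Rope import path is assumed to
# match the other llama imports, and the sketch assumes it lives alongside
# the Transformer/TransformerBlock classes defined in this file.
from executorch.examples.models.llama.attention import ATTENTION_REGISTRY
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope
from torch import nn


def construct_transformer(params: ModelArgs) -> Transformer:
    # The attention-type registry lookup that used to live inside
    # TransformerBlock now happens once, at construction time.
    if params.attention_type not in ATTENTION_REGISTRY:
        raise ValueError(f"Unknown attention type: {params.attention_type}")
    attn_cls = ATTENTION_REGISTRY[params.attention_type]

    rope = Rope(params)
    layers = nn.ModuleList(
        [
            TransformerBlock(params, attn_cls(params, layer_id, rope))
            for layer_id in range(params.n_layers)
        ]
    )
    return Transformer(params, layers, rope)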