@@ -19,7 +19,8 @@
 
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.norm import RMSNorm
-from executorch.examples.models.llama.rope import Rope
+
+# from executorch.examples.models.llama.rope import Rope
 from torch import nn
 
 
@@ -83,7 +84,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
+    # def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
+    def __init__(self, args: ModelArgs, attention: Attention):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
@@ -94,8 +96,11 @@ def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
9496 f"Unknown attention type: { args .attention_type } . "
9597 f"Available: { list (ATTENTION_REGISTRY .keys ())} "
9698 )
97- cls = ATTENTION_REGISTRY [args .attention_type ]
98- self .attention = cls (args , layer_id , rope )
99+
100+ self .attention = attention
101+
102+ # cls = ATTENTION_REGISTRY[args.attention_type]
103+ # self.attention = cls(args, layer_id, rope)
99104 if args .moe :
100105 self .block_sparse_moe = MOEFeedForward (args )
101106 else :
@@ -117,7 +122,7 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x:
 
 
 class Transformer(nn.Module):
-    def __init__(self, params: ModelArgs):
+    def __init__(self, params: ModelArgs, layers: nn.ModuleList):
         super().__init__()
         self.params = params
         self.vocab_size = params.vocab_size
@@ -130,10 +135,11 @@ def __init__(self, params: ModelArgs):
             if self.apply_embedding
             else None
         )
-        self.rope = Rope(params)
-        self.layers = torch.nn.ModuleList()
-        for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params, self.rope))
+        # self.rope = Rope(params)
+        # self.layers = torch.nn.ModuleList()
+        self.layers = layers
+        # for layer_id in range(params.n_layers):
+        #     self.layers.append(TransformerBlock(layer_id, params, self.rope))
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = (
             nn.Linear(params.dim, params.vocab_size, bias=False)
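With this change, `TransformerBlock` takes a pre-built `Attention` module and `Transformer` takes a pre-built `nn.ModuleList` of layers, so the `Rope`/registry construction that was previously done inside the constructors now has to happen at the call site. Below is a minimal construction sketch of what that caller-side wiring could look like; the exact module paths (`attention`, `llama_transformer`) and the default-constructible `ModelArgs()` are assumptions for illustration, not part of this commit.

```python
# Hedged sketch of caller-side construction after this refactor.
# Assumptions: ATTENTION_REGISTRY/Attention live in the llama `attention`
# module, Transformer/TransformerBlock in `llama_transformer`, and
# ModelArgs() is default-constructible; adjust paths/args to your setup.
import torch

from executorch.examples.models.llama.attention import ATTENTION_REGISTRY
from executorch.examples.models.llama.llama_transformer import (
    Transformer,
    TransformerBlock,
)
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope

params = ModelArgs()  # assumed defaults; normally loaded from a checkpoint config
rope = Rope(params)

# Build each layer's attention via the registry (the lookup the constructors
# used to do internally), then inject it into the block.
layers = torch.nn.ModuleList()
for layer_id in range(params.n_layers):
    attn_cls = ATTENTION_REGISTRY[params.attention_type]
    layers.append(TransformerBlock(params, attn_cls(params, layer_id, rope)))

# Inject the fully built layer stack into the Transformer.
model = Transformer(params, layers)
```

The upside of this dependency-injection style is that attention variants (or mock layers in tests) can be swapped in without touching the `Transformer`/`TransformerBlock` constructors.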