@@ -143,6 +143,69 @@ def __post_init__(self):
         self.hidden_dim = find_multiple(hidden_dim, multiple_of)


+class Rope(torch.nn.Module):
+    def __init__(self, params: ModelArgs):
+        super().__init__()
+        self.params = params
+        if self.params.use_hf_rope:
+            self.precompute_freqs_cis = hf_precompute_freqs_cis
+        else:
+            self.precompute_freqs_cis = partial(
+                precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
+            )
+        freqs_cos, freqs_sin = self.precompute_freqs_cis(
+            self.params.dim // self.params.n_heads,
+            (
+                self.params.max_seq_len  # Normal llama2.
+                if self.params.ffn_dim_multiplier is None
+                else self.params.max_seq_len * 2  # Sharded checkpoint.
+            ),
+            self.params.rope_freq_base,
+        )
+        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
+        self.register_buffer("freqs_sin", freqs_sin, persistent=False)
+        if self.params.use_hf_rope:
+            self.apply_rotary_emb = hf_apply_rotary_emb
+        else:
+            self.apply_rotary_emb = RotaryEmbedding()
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        seq_len: int,
+        input_pos: Optional[torch.Tensor] = None,
+    ):
+        if self.params.use_kv_cache:
+            assert (
+                input_pos is not None
+            ), "input_pos must be provided when use_kv_cache is True"
+
+            if self.params.enable_dynamic_shape:
+                # When the KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
+                input_pos_item = input_pos[-1].item()
+                torch._check_is_size(input_pos_item)
+                torch._check(input_pos_item < self.params.max_seq_len)
+                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
+                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seq_len)
+                # pyre-ignore: Incompatible parameter type [6]
+                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seq_len)
+            else:
+                # When not using dynamic shape, use of .item() results in
+                # symints, due to querying the data from the tensor.
+                # This path avoids that for the MPS backend, although the MPS
+                # backend can probably support dynamic shape.
+                freqs_cos = self.freqs_cos[input_pos]
+                freqs_sin = self.freqs_sin[input_pos]
+
+        else:
+            assert input_pos is None, "input_pos is unused when use_kv_cache is False"
+            freqs_cos = self.freqs_cos[:seq_len]
+            freqs_sin = self.freqs_sin[:seq_len]
+        q, k = self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)
+        return q, k
+
+
 class KVCache(nn.Module):
     def __init__(
         self,
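For orientation: the new Rope module owns the precomputed freqs_cos/freqs_sin buffers and, in forward, selects the rows for the current positions (via input_pos when the KV cache is enabled, otherwise the first seq_len rows) before applying the rotary embedding. A minimal standalone sketch of that contract follows; the ModelArgs values, and the assumption that it can be built with just these keyword arguments, are illustrative and not taken from this commit.

# Sketch only: exercises Rope outside the model. The ModelArgs construction
# below is a hypothetical config, not something this commit adds.
import torch

params = ModelArgs(dim=2048, n_heads=32, n_kv_heads=8, max_seq_len=128)
rope = Rope(params)

bsz, seqlen = 1, 16
head_dim = params.dim // params.n_heads
q = torch.randn(bsz, seqlen, params.n_heads, head_dim)
k = torch.randn(bsz, seqlen, params.n_kv_heads, head_dim)

# With use_kv_cache=False, input_pos stays None and the first `seqlen`
# rows of the precomputed freqs_cos/freqs_sin buffers are used.
q_rot, k_rot = rope.forward(q, k, seqlen)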
@@ -262,7 +325,7 @@ def forward(


 class Attention(nn.Module):
-    def __init__(self, args: ModelArgs, layer_id: int):
+    def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
@@ -284,6 +347,8 @@ def __init__(self, args: ModelArgs, layer_id: int):

         self.layer_id = layer_id

+        self.rope = rope
+
         causal_mask = torch.tril(
             torch.ones(
                 self.max_seq_len,
@@ -300,7 +365,7 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 args.max_seq_len,
                 self.n_kv_heads,
                 self.head_dim,
-                not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op dont transpose the cache. Expect untransposed q k v
+                not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op don't transpose the cache. Expect untransposed q k v
                 args.enable_dynamic_shape,
             )
             self.SDPA = SDPA(
@@ -311,16 +376,10 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 max_seq_len=self.max_seq_len,
                 enable_dynamic_shape=args.enable_dynamic_shape,
             )
-        if args.use_hf_rope:
-            self.apply_rotary_emb = hf_apply_rotary_emb
-        else:
-            self.apply_rotary_emb = RotaryEmbedding()

     def forward(
         self,
         x: torch.Tensor,
-        freqs_cos: torch.Tensor,
-        freqs_sin: torch.Tensor,
         input_pos: Optional[torch.Tensor] = None,
     ):
         bsz, seqlen, _ = x.shape
@@ -333,7 +392,7 @@ def forward(
         v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

         # RoPE relative positional embeddings
-        q, k = self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)
+        q, k = self.rope.forward(q, k, seqlen, input_pos)

         if self.use_kv_cache:
             assert input_pos is not None
@@ -421,24 +480,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


 class TransformerBlock(nn.Module):
-    def __init__(self, layer_id: int, args: ModelArgs):
+    def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.dim // args.n_heads
-        self.attention = Attention(args, layer_id)
+        self.attention = Attention(args, layer_id, rope)
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
             self.feed_forward = FeedForward(args)
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

-    def forward(self, x, freqs_cos, freqs_sin, input_pos=None):  # x: 1xN
-        h = self.attention.forward(
-            self.attention_norm(x), freqs_cos, freqs_sin, input_pos
-        )
+    def forward(self, x, input_pos=None):  # x: 1xN
+        h = self.attention.forward(self.attention_norm(x), input_pos)

         h = x + h
         if hasattr(self, "block_sparse_moe"):
@@ -456,33 +513,17 @@ def __init__(self, params: ModelArgs):
         self.n_layers = params.n_layers

         self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
+        self.rope = Rope(params)
         self.layers = torch.nn.ModuleList()
         for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params))
+            self.layers.append(TransformerBlock(layer_id, params, self.rope))
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
         self.use_kv_cache = params.use_kv_cache
         self.generate_full_logits = params.generate_full_logits
         self.max_seq_len = params.max_seq_len
         self.input_prune_map = params.input_prune_map
         self.output_prune_map = params.output_prune_map
-        if params.use_hf_rope:
-            self.precompute_freqs_cis = hf_precompute_freqs_cis
-        else:
-            self.precompute_freqs_cis = partial(
-                precompute_freqs_cis, use_scaled=params.use_scaled_rope
-            )
-        freqs_cos, freqs_sin = self.precompute_freqs_cis(
-            params.dim // params.n_heads,
-            (
-                params.max_seq_len  # Normal llama2.
-                if params.ffn_dim_multiplier is None
-                else params.max_seq_len * 2  # Sharded checkpoint.
-            ),
-            params.rope_freq_base,
-        )
-        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
-        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

     def forward(
         self,
@@ -498,42 +539,9 @@ def forward(
             )
         if tokens is not None and h is None:
             h = self.tok_embeddings(tokens)
-        seqlen = h.shape[1]
-
-        if self.use_kv_cache:
-            assert (
-                input_pos is not None
-            ), "input_pos must be provided when use_kv_cache is True"
-
-            if self.params.enable_dynamic_shape:
-                # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
-                input_pos_item = input_pos[-1].item()
-                torch._check_is_size(input_pos_item)
-                torch._check(input_pos_item < self.params.max_seq_len)
-                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
-                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seqlen)
-                # pyre-ignore: Incompatible parameter type [6]
-                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seqlen)
-            else:
-                # When not using dynamic shape, use of the .item results in
-                # symints, due to querying the data from tensor.
-                # this path avoids that for mps backend, although probably mps backend
-                # can support dynamic shape?
-                freqs_cos = self.freqs_cos[input_pos]
-                freqs_sin = self.freqs_sin[input_pos]
-
-        else:
-            assert input_pos is None, "input_pos is unused when use_kv_cache is False"
-            freqs_cos = self.freqs_cos[:seqlen]
-            freqs_sin = self.freqs_sin[:seqlen]

         for layer in self.layers:
-            h = layer(
-                h,
-                freqs_cos,
-                freqs_sin,
-                input_pos,
-            )
+            h = layer(h, input_pos)

         if not self.generate_full_logits:
             # Only the last logit is used for the new generated token
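With the freqs plumbing moved into Rope, the model's public forward contract is unchanged for callers: they still pass tokens and, when the KV cache is enabled, an input_pos tensor, and the shared Rope does the slicing internally. Below is a hedged sketch of one decode step, assuming a params config loaded elsewhere with use_kv_cache=True and the tokens/input_pos signature shown above.

# Sketch of a single KV-cache decode step against the refactored Transformer.
# Assumes `params` is a real ModelArgs config (use_kv_cache=True,
# generate_full_logits=False); the token id and position are placeholders.
import torch

model = Transformer(params).eval()

tokens = torch.tensor([[1]], dtype=torch.long)   # current token, shape (bsz=1, seqlen=1)
input_pos = torch.tensor([0], dtype=torch.long)  # write position in the KV cache

with torch.no_grad():
    # Rope slices freqs_cos/freqs_sin at input_pos internally; no freqs are
    # passed through the layers anymore.
    logits = model(tokens, input_pos=input_pos)
    # With generate_full_logits=False, only the last position's logits come back.
    next_token = logits.argmax(dim=-1)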