@@ -143,6 +143,69 @@ def __post_init__(self):
         self.hidden_dim = find_multiple(hidden_dim, multiple_of)
 
 
+class Rope(torch.nn.Module):
+    def __init__(self, params: ModelArgs):
+        super().__init__()
+        self.params = params
+        if self.params.use_hf_rope:
+            self.precompute_freqs_cis = hf_precompute_freqs_cis
+        else:
+            self.precompute_freqs_cis = partial(
+                precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
+            )
+        freqs_cos, freqs_sin = self.precompute_freqs_cis(
+            self.params.dim // self.params.n_heads,
+            (
+                self.params.max_seq_len  # Normal llama2.
+                if self.params.ffn_dim_multiplier is None
+                else self.params.max_seq_len * 2  # Sharded checkpoint.
+            ),
+            self.params.rope_freq_base,
+        )
+        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
+        self.register_buffer("freqs_sin", freqs_sin, persistent=False)
+        if self.params.use_hf_rope:
+            self.apply_rotary_emb = hf_apply_rotary_emb
+        else:
+            self.apply_rotary_emb = RotaryEmbedding()
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        seq_len: int,
+        input_pos: Optional[torch.LongTensor] = None,
+    ):
+        if self.params.use_kv_cache:
+            assert (
+                input_pos is not None
+            ), "input_pos must be provided when use_kv_cache is True"
+
+            if self.params.enable_dynamic_shape:
+                # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
+                input_pos_item = input_pos[-1].item()
+                torch._check_is_size(input_pos_item)
+                torch._check(input_pos_item < self.params.max_seq_len)
+                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
+                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seq_len)
+                # pyre-ignore: Incompatible parameter type [6]
+                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seq_len)
+            else:
+                # When not using dynamic shape, use of the .item results in
+                # symints, due to querying the data from tensor.
+                # this path avoids that for mps backend, although probably mps backend
+                # can support dynamic shape?
+                freqs_cos = self.freqs_cos[input_pos]
+                freqs_sin = self.freqs_sin[input_pos]
+
+        else:
+            assert input_pos is None, "input_pos is unused when use_kv_cache is False"
+            freqs_cos = self.freqs_cos[:seq_len]
+            freqs_sin = self.freqs_sin[:seq_len]
+        q, k = self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)
+        return q, k
+
+
 class KVCache(nn.Module):
     def __init__(
         self,
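
The Rope module above only precomputes and slices the cos/sin tables; the actual math lives in precompute_freqs_cis / apply_rotary_emb (or their HF variants), which are imported from elsewhere and are not part of this diff. As a rough, illustrative sketch of what such helpers typically compute — the sketch_* names are hypothetical and the interleaved pairing convention is an assumption, not necessarily what this repo uses:

```python
import torch

def sketch_precompute_freqs_cis(head_dim: int, max_seq_len: int, theta: float = 10000.0):
    # One inverse frequency per channel pair, one row of angles per position.
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(max_seq_len).float()
    angles = torch.outer(t, inv_freq)            # (max_seq_len, head_dim // 2)
    return torch.cos(angles), torch.sin(angles)  # freqs_cos, freqs_sin

def sketch_apply_rotary_emb(q, k, freqs_cos, freqs_sin):
    # q, k: (bsz, seq_len, n_heads, head_dim); freqs_*: (seq_len, head_dim // 2).
    cos = freqs_cos[:, None, :]  # broadcast across heads
    sin = freqs_sin[:, None, :]

    def rotate(x):
        x_r, x_i = x[..., 0::2], x[..., 1::2]  # treat channel pairs as complex numbers
        out_r = x_r * cos - x_i * sin          # rotate each pair by its per-position angle
        out_i = x_r * sin + x_i * cos
        return torch.stack((out_r, out_i), dim=-1).flatten(-2)

    return rotate(q), rotate(k)
```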
@@ -262,7 +325,7 @@ def forward(
 
 
 class Attention(nn.Module):
-    def __init__(self, args: ModelArgs, layer_id: int):
+    def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
@@ -284,6 +347,8 @@ def __init__(self, args: ModelArgs, layer_id: int):
 
         self.layer_id = layer_id
 
+        self.rope = rope
+
         causal_mask = torch.tril(
             torch.ones(
                 self.max_seq_len,
@@ -300,7 +365,8 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 args.max_seq_len,
                 self.n_kv_heads,
                 self.head_dim,
-                not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op dont transpose the cache. Expect untransposed q k v
+                not args.use_sdpa_with_kv_cache_op,
+                # if we are using the custom op don't transpose the cache. Expect untransposed q k v
                 args.enable_dynamic_shape,
             )
             self.SDPA = SDPA(
@@ -311,16 +377,10 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 max_seq_len=self.max_seq_len,
                 enable_dynamic_shape=args.enable_dynamic_shape,
             )
-        if args.use_hf_rope:
-            self.apply_rotary_emb = hf_apply_rotary_emb
-        else:
-            self.apply_rotary_emb = RotaryEmbedding()
 
     def forward(
         self,
         x: torch.Tensor,
-        freqs_cos: torch.Tensor,
-        freqs_sin: torch.Tensor,
         input_pos: Optional[torch.Tensor] = None,
     ):
         bsz, seqlen, _ = x.shape
@@ -333,7 +393,7 @@ def forward(
         v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
 
         # RoPE relative positional embeddings
-        q, k = self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)
+        q, k = self.rope.forward(q, k, seqlen, input_pos)
 
         if self.use_kv_cache:
             assert input_pos is not None
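
After this change, Attention.forward no longer threads freqs_cos/freqs_sin through; it hands q, k, the sequence length, and input_pos to the shared Rope module. A minimal usage sketch of that call pattern, with made-up ModelArgs values and assuming the remaining fields (use_hf_rope, use_kv_cache, rope_freq_base, ...) have workable defaults:

```python
import torch

args = ModelArgs(dim=512, n_heads=8, max_seq_len=128)  # hypothetical values
rope = Rope(args)

bsz, seqlen, head_dim = 2, 16, args.dim // args.n_heads
q = torch.randn(bsz, seqlen, args.n_heads, head_dim)
k = torch.randn(bsz, seqlen, args.n_heads, head_dim)

# Prefill-style call: with use_kv_cache=False, input_pos stays None and the
# first `seqlen` rows of the precomputed cos/sin buffers are used.
q_rot, k_rot = rope.forward(q, k, seqlen, input_pos=None)
```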
@@ -421,13 +481,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, layer_id: int, args: ModelArgs):
+    def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.dim // args.n_heads
-        self.attention = Attention(args, layer_id)
+        self.attention = Attention(args, layer_id, rope)
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
@@ -456,33 +516,17 @@ def __init__(self, params: ModelArgs):
         self.n_layers = params.n_layers
 
         self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
+        self.rope = Rope(params)
         self.layers = torch.nn.ModuleList()
         for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params))
+            self.layers.append(TransformerBlock(layer_id, params, self.rope))
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
         self.use_kv_cache = params.use_kv_cache
         self.generate_full_logits = params.generate_full_logits
         self.max_seq_len = params.max_seq_len
         self.input_prune_map = params.input_prune_map
         self.output_prune_map = params.output_prune_map
-        if params.use_hf_rope:
-            self.precompute_freqs_cis = hf_precompute_freqs_cis
-        else:
-            self.precompute_freqs_cis = partial(
-                precompute_freqs_cis, use_scaled=params.use_scaled_rope
-            )
-        freqs_cos, freqs_sin = self.precompute_freqs_cis(
-            params.dim // params.n_heads,
-            (
-                params.max_seq_len  # Normal llama2.
-                if params.ffn_dim_multiplier is None
-                else params.max_seq_len * 2  # Sharded checkpoint.
-            ),
-            params.rope_freq_base,
-        )
-        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
-        self.register_buffer("freqs_sin", freqs_sin, persistent=False)
 
     def forward(
         self,
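
A single Rope is now constructed in Transformer.__init__ and passed down to every TransformerBlock and Attention, so the freqs_cos/freqs_sin buffers are registered once on the shared module rather than on the Transformer itself. A hypothetical sanity check of that sharing (constructor values are made up; the other ModelArgs fields are assumed to default sensibly):

```python
model = Transformer(ModelArgs(dim=256, n_heads=4, n_layers=2, vocab_size=32000))

# Every layer's attention should reference the same Rope instance and buffers.
assert all(layer.attention.rope is model.rope for layer in model.layers)
assert model.rope.freqs_cos.shape[0] >= model.max_seq_len
```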
@@ -498,42 +542,9 @@ def forward(
             )
         if tokens is not None and h is None:
             h = self.tok_embeddings(tokens)
-        seqlen = h.shape[1]
-
-        if self.use_kv_cache:
-            assert (
-                input_pos is not None
-            ), "input_pos must be provided when use_kv_cache is True"
-
-            if self.params.enable_dynamic_shape:
-                # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
-                input_pos_item = input_pos[-1].item()
-                torch._check_is_size(input_pos_item)
-                torch._check(input_pos_item < self.params.max_seq_len)
-                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
-                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seqlen)
-                # pyre-ignore: Incompatible parameter type [6]
-                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seqlen)
-            else:
-                # When not using dynamic shape, use of the .item results in
-                # symints, due to querying the data from tensor.
-                # this path avoids that for mps backend, although probably mps backend
-                # can support dynamic shape?
-                freqs_cos = self.freqs_cos[input_pos]
-                freqs_sin = self.freqs_sin[input_pos]
-
-        else:
-            assert input_pos is None, "input_pos is unused when use_kv_cache is False"
-            freqs_cos = self.freqs_cos[:seqlen]
-            freqs_sin = self.freqs_sin[:seqlen]
 
         for layer in self.layers:
-            h = layer(
-                h,
-                freqs_cos,
-                freqs_sin,
-                input_pos,
-            )
+            h = layer(h, input_pos)
 
         if not self.generate_full_logits:
             # Only the last logit is used for the new generated token
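
With the RoPE bookkeeping moved into Rope, Transformer.forward shrinks to an embedding lookup plus `layer(h, input_pos)` per block. A hypothetical single decode step against the refactored model; the keyword names follow this diff, but the exact forward signature and the remaining ModelArgs defaults (max_batch_size, etc.) are assumptions:

```python
import torch

args = ModelArgs(
    dim=256, n_heads=4, n_layers=2, vocab_size=32000,
    max_seq_len=64, use_kv_cache=True, enable_dynamic_shape=False,
)
model = Transformer(args).eval()

tokens = torch.tensor([[5]])    # the single new token for this step
input_pos = torch.tensor([0])   # its position in the sequence
with torch.no_grad():
    logits = model(tokens, input_pos=input_pos)  # RoPE slicing now happens inside Rope
```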