@@ -143,16 +143,73 @@ def __post_init__(self):
         self.hidden_dim = find_multiple(hidden_dim, multiple_of)
 
 
+class Rope(torch.nn.Module):
+    def __init__(self, params: ModelArgs):
+        super().__init__()
+        self.params = params
+        if self.params.use_hf_rope:
+            self.precompute_freqs_cis = hf_precompute_freqs_cis
+        else:
+            self.precompute_freqs_cis = partial(
+                precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
+            )
+        freqs_cos, freqs_sin = self.precompute_freqs_cis(
+            self.params.dim // self.params.n_heads,
+            (
+                self.params.max_seq_len  # Normal llama2.
+                if self.params.ffn_dim_multiplier is None
+                else self.params.max_seq_len * 2  # Sharded checkpoint.
+            ),
+            self.params.rope_freq_base,
+        )
+        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
+        self.register_buffer("freqs_sin", freqs_sin, persistent=False)
+        if self.params.use_hf_rope:
+            self.apply_rotary_emb = hf_apply_rotary_emb
+        else:
+            self.apply_rotary_emb = RotaryEmbedding()
+
+    def forward(self, q: torch.Tensor, k: torch.Tensor, seq_len: int, input_pos: Optional[torch.LongTensor] = None):
+        if self.params.use_kv_cache:
+            assert (
+                input_pos is not None
+            ), "input_pos must be provided when use_kv_cache is True"
+
+            if self.params.enable_dynamic_shape:
+                # When the KV cache is used, seq_len is most likely 1. We want to slice from input_pos.
+                input_pos_item = input_pos[-1].item()
+                torch._check_is_size(input_pos_item)
+                torch._check(input_pos_item < self.params.max_seq_len)
+                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
+                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seq_len)
+                # pyre-ignore: Incompatible parameter type [6]
+                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seq_len)
+            else:
+                # When not using dynamic shape, calling .item() would introduce
+                # symints because it queries data from the tensor.
+                # This path avoids that for the MPS backend, although MPS
+                # may be able to support dynamic shape.
+                freqs_cos = self.freqs_cos[input_pos]
+                freqs_sin = self.freqs_sin[input_pos]
+
+        else:
+            assert input_pos is None, "input_pos is unused when use_kv_cache is False"
+            freqs_cos = self.freqs_cos[:seq_len]
+            freqs_sin = self.freqs_sin[:seq_len]
+        q, k = self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)
+        return q, k
+
+
 class KVCache(nn.Module):
     def __init__(
-        self,
-        max_batch_size: int,
-        max_seq_length: int,
-        n_heads: int,
-        head_dim: int,
-        transpose_cache: bool,
-        enable_dynamic_shape: bool,
-        dtype=torch.float32,
+        self,
+        max_batch_size: int,
+        max_seq_length: int,
+        n_heads: int,
+        head_dim: int,
+        transpose_cache: bool,
+        enable_dynamic_shape: bool,
+        dtype=torch.float32,
     ):
         super().__init__()
         self.max_seq_length = max_seq_length
@@ -175,7 +232,7 @@ def __init__(
         )
 
     def update(
-        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
+        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
         if self.enable_dynamic_shape:
@@ -213,13 +270,13 @@ def update(
 
 class SDPA(nn.Module):
     def __init__(
-        self,
-        kv_cache: KVCache,
-        dim: int,
-        head_dim: int,
-        n_rep: int,
-        max_seq_len: int,
-        enable_dynamic_shape: bool,
+        self,
+        kv_cache: KVCache,
+        dim: int,
+        head_dim: int,
+        n_rep: int,
+        max_seq_len: int,
+        enable_dynamic_shape: bool,
     ):
         super().__init__()
         self.kv_cache = kv_cache
@@ -230,14 +287,14 @@ def __init__(
         self.enable_dynamic_shape = enable_dynamic_shape
 
     def forward(
-        self,
-        input_pos: torch.Tensor,
-        q: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim)
-        k: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim)
-        v: torch.Tensor,  # (bs, seqlen, n_local_kv_heads, head_dim)
-        bsz,
-        seqlen,
-        mask: torch.Tensor,
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim)
+        k: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim)
+        v: torch.Tensor,  # (bs, seqlen, n_local_kv_heads, head_dim)
+        bsz,
+        seqlen,
+        mask: torch.Tensor,
     ) -> torch.Tensor:
         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
@@ -262,7 +319,7 @@ def forward(
 
 
 class Attention(nn.Module):
-    def __init__(self, args: ModelArgs, layer_id: int):
+    def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
@@ -284,6 +341,8 @@ def __init__(self, args: ModelArgs, layer_id: int):
 
         self.layer_id = layer_id
 
+        self.rope = rope
+
         causal_mask = torch.tril(
             torch.ones(
                 self.max_seq_len,
@@ -300,7 +359,8 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 args.max_seq_len,
                 self.n_kv_heads,
                 self.head_dim,
-                not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op dont transpose the cache. Expect untransposed q k v
+                not args.use_sdpa_with_kv_cache_op,
+                # If we are using the custom op, don't transpose the cache. Expect untransposed q, k, v.
                 args.enable_dynamic_shape,
             )
             self.SDPA = SDPA(
@@ -311,17 +371,11 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 max_seq_len=self.max_seq_len,
                 enable_dynamic_shape=args.enable_dynamic_shape,
             )
-        if args.use_hf_rope:
-            self.apply_rotary_emb = hf_apply_rotary_emb
-        else:
-            self.apply_rotary_emb = RotaryEmbedding()
 
     def forward(
-        self,
-        x: torch.Tensor,
-        freqs_cos: torch.Tensor,
-        freqs_sin: torch.Tensor,
-        input_pos: Optional[torch.Tensor] = None,
+        self,
+        x: torch.Tensor,
+        input_pos: Optional[torch.Tensor] = None,
     ):
         bsz, seqlen, _ = x.shape
 
@@ -333,7 +387,7 @@ def forward(
         v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
 
         # RoPE relative positional embeddings
-        q, k = self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)
+        q, k = self.rope.forward(q, k, seqlen, input_pos)
 
         if self.use_kv_cache:
             assert input_pos is not None
@@ -421,13 +475,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, layer_id: int, args: ModelArgs):
+    def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.dim // args.n_heads
-        self.attention = Attention(args, layer_id)
+        self.attention = Attention(args, layer_id, rope)
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
@@ -456,84 +510,35 @@ def __init__(self, params: ModelArgs):
         self.n_layers = params.n_layers
 
         self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
+        self.rope = Rope(params)
         self.layers = torch.nn.ModuleList()
         for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params))
+            self.layers.append(TransformerBlock(layer_id, params, self.rope))
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
         self.use_kv_cache = params.use_kv_cache
         self.generate_full_logits = params.generate_full_logits
         self.max_seq_len = params.max_seq_len
         self.input_prune_map = params.input_prune_map
         self.output_prune_map = params.output_prune_map
-        if params.use_hf_rope:
-            self.precompute_freqs_cis = hf_precompute_freqs_cis
-        else:
-            self.precompute_freqs_cis = partial(
-                precompute_freqs_cis, use_scaled=params.use_scaled_rope
-            )
-        freqs_cos, freqs_sin = self.precompute_freqs_cis(
-            params.dim // params.n_heads,
-            (
-                params.max_seq_len  # Normal llama2.
-                if params.ffn_dim_multiplier is None
-                else params.max_seq_len * 2  # Sharded checkpoint.
-            ),
-            params.rope_freq_base,
-        )
-        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
-        self.register_buffer("freqs_sin", freqs_sin, persistent=False)
 
     def forward(
-        self,
-        tokens: Optional[torch.LongTensor] = None,  # tokens
-        input_pos: Optional[
-            torch.LongTensor
-        ] = None,  # Scalar tensor indicating size of window of the caches
-        h: Optional[torch.FloatTensor] = None,  # embeddings
+        self,
+        tokens: Optional[torch.LongTensor] = None,  # tokens
+        input_pos: Optional[
+            torch.LongTensor
+        ] = None,  # Scalar tensor indicating size of window of the caches
+        h: Optional[torch.FloatTensor] = None,  # embeddings
     ) -> torch.Tensor:
         if (tokens is None) ^ (h is not None):
             raise ValueError(
                 "You cannot specify both tokens and h at the same time, and must specify either one"
             )
         if tokens is not None and h is None:
             h = self.tok_embeddings(tokens)
-        seqlen = h.shape[1]
-
-        if self.use_kv_cache:
-            assert (
-                input_pos is not None
-            ), "input_pos must be provided when use_kv_cache is True"
-
-            if self.params.enable_dynamic_shape:
-                # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
-                input_pos_item = input_pos[-1].item()
-                torch._check_is_size(input_pos_item)
-                torch._check(input_pos_item < self.params.max_seq_len)
-                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
-                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seqlen)
-                # pyre-ignore: Incompatible parameter type [6]
-                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seqlen)
-            else:
-                # When not using dynamic shape, use of the .item results in
-                # symints, due to querying the data from tensor.
-                # this path avoids that for mps backend, although probably mps backend
-                # can support dynamic shape?
-                freqs_cos = self.freqs_cos[input_pos]
-                freqs_sin = self.freqs_sin[input_pos]
-
-        else:
-            assert input_pos is None, "input_pos is unused when use_kv_cache is False"
-            freqs_cos = self.freqs_cos[:seqlen]
-            freqs_sin = self.freqs_sin[:seqlen]
 
         for layer in self.layers:
-            h = layer(
-                h,
-                freqs_cos,
-                freqs_sin,
-                input_pos,
-            )
+            h = layer(h, input_pos)
 
         if not self.generate_full_logits:
             # Only the last logit is used for the new generated token
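For context, here is a minimal usage sketch of the Rope module introduced by this diff. It assumes the ModelArgs and Rope classes from the modified file are in scope and that the remaining ModelArgs fields keep their defaults; the field values below are illustrative, not taken from the diff.

import torch

# Hypothetical config; field names follow the diff, values are made up for illustration.
params = ModelArgs(dim=256, n_heads=8, max_seq_len=64, use_kv_cache=False)
rope = Rope(params)

bsz, seqlen = 1, 16
head_dim = params.dim // params.n_heads
q = torch.randn(bsz, seqlen, params.n_heads, head_dim)
k = torch.randn(bsz, seqlen, params.n_heads, head_dim)

# With use_kv_cache=False, input_pos stays None and the first seqlen
# precomputed frequencies are applied to q and k.
q_rot, k_rot = rope.forward(q, k, seqlen)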