@@ -169,10 +169,16 @@ def __init__(self, params: ModelArgs):
         else:
             self.apply_rotary_emb = RotaryEmbedding()
 
-    def forward(self, q: torch.Tensor, k: torch.Tensor, seq_len: int, input_pos: Optional[torch.LongTensor] = None):
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        seq_len: int,
+        input_pos: Optional[torch.LongTensor] = None,
+    ):
         if self.params.use_kv_cache:
             assert (
-                input_pos is not None
+                input_pos is not None
             ), "input_pos must be provided when use_kv_cache is True"
 
         if self.params.enable_dynamic_shape:
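
Note (not part of the diff): a minimal sketch of how a rotary-embedding forward like this is typically driven when the KV cache is on. `input_pos` carries the absolute positions of the current tokens so their rotation angles line up with what is already cached. All names and shapes below are assumptions for illustration, not this repo's API.

```python
from typing import Tuple

import torch


def rope_table(head_dim: int, max_seq_len: int, base: float = 10000.0) -> Tuple[torch.Tensor, torch.Tensor]:
    # Precompute cos/sin tables of shape (max_seq_len, head_dim // 2).
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(max_seq_len).float(), inv_freq)
    return torch.cos(angles), torch.sin(angles)


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
    # x: (bs, seqlen, n_heads, head_dim); input_pos: (seqlen,) absolute positions into the tables.
    cos = cos[input_pos][None, :, None, :]  # broadcast over batch and heads
    sin = sin[input_pos][None, :, None, :]
    x1, x2 = x[..., 0::2], x[..., 1::2]     # even/odd feature pairs
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out
```

During single-token decode, `input_pos` would be something like `torch.tensor([cache_len])` so the new token is rotated by its absolute position rather than position 0.
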
@@ -202,14 +208,14 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, seq_len: int, input_pos: Opt
 
 class KVCache(nn.Module):
     def __init__(
-        self,
-        max_batch_size: int,
-        max_seq_length: int,
-        n_heads: int,
-        head_dim: int,
-        transpose_cache: bool,
-        enable_dynamic_shape: bool,
-        dtype=torch.float32,
+        self,
+        max_batch_size: int,
+        max_seq_length: int,
+        n_heads: int,
+        head_dim: int,
+        transpose_cache: bool,
+        enable_dynamic_shape: bool,
+        dtype=torch.float32,
     ):
         super().__init__()
         self.max_seq_length = max_seq_length
@@ -232,7 +238,7 @@ def __init__(
         )
 
     def update(
-        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
+        self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache
         if self.enable_dynamic_shape:
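
Aside (not in the diff): the cache above preallocates full-length buffers and scatters the new key/value slices in at `input_pos`. A stripped-down sketch of that update pattern, covering only the transposed (B, H, S, D) layout and with assumed class/argument names:

```python
import torch
from torch import nn


class TinyKVCache(nn.Module):
    """Illustrative cache with a (B, H, S, D) layout; not the class in this diff."""

    def __init__(self, max_batch_size, n_heads, max_seq_length, head_dim, dtype=torch.float32):
        super().__init__()
        shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: (S,); k_val / v_val: (B, H, S, D).
        # index_copy_ writes the new entries at the given sequence positions in place.
        self.k_cache.index_copy_(2, input_pos, k_val)
        self.v_cache.index_copy_(2, input_pos, v_val)
        return self.k_cache, self.v_cache
```

A decode step would call, e.g., `cache.update(torch.tensor([7]), k_step, v_step)` and attend over the returned full buffers.
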
@@ -270,13 +276,13 @@ def update(
 
 class SDPA(nn.Module):
     def __init__(
-        self,
-        kv_cache: KVCache,
-        dim: int,
-        head_dim: int,
-        n_rep: int,
-        max_seq_len: int,
-        enable_dynamic_shape: bool,
+        self,
+        kv_cache: KVCache,
+        dim: int,
+        head_dim: int,
+        n_rep: int,
+        max_seq_len: int,
+        enable_dynamic_shape: bool,
     ):
         super().__init__()
         self.kv_cache = kv_cache
@@ -287,14 +293,14 @@ def __init__(
         self.enable_dynamic_shape = enable_dynamic_shape
 
     def forward(
-        self,
-        input_pos: torch.Tensor,
-        q: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim)
-        k: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim)
-        v: torch.Tensor,  # (bs, seqlen, n_local_kv_heads, head_dim)
-        bsz,
-        seqlen,
-        mask: torch.Tensor,
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim)
+        k: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim)
+        v: torch.Tensor,  # (bs, seqlen, n_local_kv_heads, head_dim)
+        bsz,
+        seqlen,
+        mask: torch.Tensor,
     ) -> torch.Tensor:
         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
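
Aside: after the transposes shown at the end of this hunk, an SDPA-with-cache forward typically updates the cache at `input_pos`, expands the KV heads by `n_rep` for grouped-query attention, and hands everything to `torch.nn.functional.scaled_dot_product_attention`. A hedged sketch with assumed names; the real class also handles the dynamic-shape path:

```python
import torch
import torch.nn.functional as F


def sdpa_with_cache(q, k, v, k_cache, v_cache, input_pos, n_rep, mask):
    # q: (bs, n_heads, seqlen, head_dim); k, v: (bs, n_kv_heads, seqlen, head_dim).
    # Write the new keys/values into the preallocated caches at input_pos.
    k_cache.index_copy_(2, input_pos, k)
    v_cache.index_copy_(2, input_pos, v)
    # Grouped-query attention: each KV head serves n_rep query heads.
    k_all = k_cache.repeat_interleave(n_rep, dim=1)
    v_all = v_cache.repeat_interleave(n_rep, dim=1)
    # mask restricts each query position to the cached positions it may attend to.
    return F.scaled_dot_product_attention(q, k_all, v_all, attn_mask=mask)
```
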
@@ -373,9 +379,9 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         )
 
     def forward(
-        self,
-        x: torch.Tensor,
-        input_pos: Optional[torch.Tensor] = None,
+        self,
+        x: torch.Tensor,
+        input_pos: Optional[torch.Tensor] = None,
     ):
         bsz, seqlen, _ = x.shape
 
@@ -523,12 +529,12 @@ def __init__(self, params: ModelArgs):
         self.output_prune_map = params.output_prune_map
 
     def forward(
-        self,
-        tokens: Optional[torch.LongTensor] = None,  # tokens
-        input_pos: Optional[
-            torch.LongTensor
-        ] = None,  # Scalar tensor indicating size of window of the caches
-        h: Optional[torch.FloatTensor] = None,  # embeddings
+        self,
+        tokens: Optional[torch.LongTensor] = None,  # tokens
+        input_pos: Optional[
+            torch.LongTensor
+        ] = None,  # Scalar tensor indicating size of window of the caches
+        h: Optional[torch.FloatTensor] = None,  # embeddings
     ) -> torch.Tensor:
         if (tokens is None) ^ (h is not None):
             raise ValueError(
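
Aside: the guard at the end of this hunk requires exactly one of `tokens` or `h` to be supplied. Written out as a tiny standalone check, with an assumed embedding callable standing in for the model's token embedding:

```python
from typing import Callable, Optional

import torch


def resolve_model_input(
    tok_embeddings: Callable[[torch.LongTensor], torch.Tensor],  # assumed embedding lookup
    tokens: Optional[torch.LongTensor] = None,
    h: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
    # (tokens is None) ^ (h is not None) is True exactly when both or neither are given.
    if (tokens is None) ^ (h is not None):
        raise ValueError("You must specify exactly one of tokens or h")
    return h if h is not None else tok_embeddings(tokens)
```
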