@@ -246,6 +246,7 @@ def forward(
         # x has shape [b, s_x, d]
         # y has shape [b, s_y, d]
         b, s_x, _ = x.shape
+        s_y = y.shape[1] if y is not None else 0

         # q has shape [b, s_x, num_heads * head_dim]
         q = self.q_proj(x)
@@ -262,9 +263,16 @@ def forward(
         if self.q_norm is not None:
             q = self.q_norm(q)

-        def calculate_kv(y):
+        if y is None:
+            if self.kv_cache is None:
+                raise ValueError(
+                    "Must provide y input or use kv_cache to enable streaming decoding"
+                )
+            k = self.kv_cache.k_cache
+            v = self.kv_cache.v_cache
+        else:
             # Update k and v shape, positional embeddings, and normalization
-            s_y = y.shape[1]
+
             # k has shape [b, s_y, num_kv_heads * head_dim]
             # v has shape [b, s_y, num_kv_heads * head_dim]
             k = self.k_proj(y)
@@ -280,37 +288,12 @@ def calculate_kv(y):
             # Normalize k
             if self.k_norm is not None:
                 k = self.k_norm(k)
-            return k, v
-
-        def true_fn(y):
-            kv_cache = self.kv_cache.clone()
-            return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos

-        def false_fn(y):
-            k, v = calculate_kv(y)
-            kv_cache = self.kv_cache.clone()
-            kv_cache.update(k, v)
-            return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos
-
-        # If kv cache is None, we expect y to be provided
-        if self.kv_cache is None:
-            assert (
-                y is not None
-            ), "Must provide y input or use kv_cache to enable streaming decoding"
-            k, v = calculate_kv(y)
-        else:
-            # Expecting the k, v returning here to be the same size of self.kv_cache
-            # In eager, we expect this predicate to specialize. In export, this will
-            # become a SymBool so it's not specialized.
-            k, v, cache_pos = torch.cond(
-                torch.isnan(y).all().item(), true_fn, false_fn, (y,)
-            )
         # Update key-value cache
-        self.kv_cache.k_cache.copy_(k)
-        self.kv_cache.v_cache.copy_(v)
-        self.kv_cache.cache_pos.copy_(cache_pos)
+        if self.kv_cache is not None and self.cache_enabled:
+            k, v = self.kv_cache.update(k, v)

-        output = self._sdpa(q, k, v, b, s_x, mask=mask)
+        output = self._sdpa(q, k, v, b, s_x)
         return self.output_proj(output)
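For context, a minimal sketch of the cache interface the reworked `forward()` now leans on: when `y` is provided, the projected k/v are written into the cache via `update()` (which hands back the full cached tensors); when `y` is `None`, k/v are read straight from the `k_cache`/`v_cache` buffers. The class below is a toy stand-in, not the library's actual `KVCache` (the real layer also reshapes and position-embeds k/v before updating, which is omitted here, and its `cache_pos` is a tensor rather than an int).

```python
import torch

class SimpleKVCache:
    """Toy stand-in for the kv_cache used above: preallocated buffers plus an
    update() that writes the new tokens' k/v at the current position and
    returns the full cached tensors, mirroring `k, v = self.kv_cache.update(k, v)`."""

    def __init__(self, batch, max_seq_len, num_kv_heads, head_dim, dtype=torch.float32):
        self.k_cache = torch.zeros(batch, max_seq_len, num_kv_heads * head_dim, dtype=dtype)
        self.v_cache = torch.zeros(batch, max_seq_len, num_kv_heads * head_dim, dtype=dtype)
        self.cache_pos = 0  # next write position along the sequence dim

    def update(self, k, v):
        # k, v: [b, s_y, num_kv_heads * head_dim] for the new tokens only
        s_y = k.shape[1]
        self.k_cache[:, self.cache_pos : self.cache_pos + s_y] = k
        self.v_cache[:, self.cache_pos : self.cache_pos + s_y] = v
        self.cache_pos += s_y
        # Return the full caches, which is what the attention layer consumes
        return self.k_cache, self.v_cache

# Prefill: y is provided, so k/v are projected and written into the cache.
cache = SimpleKVCache(batch=2, max_seq_len=16, num_kv_heads=4, head_dim=8)
k_new = torch.randn(2, 5, 4 * 8)
v_new = torch.randn(2, 5, 4 * 8)
k, v = cache.update(k_new, v_new)    # path taken when y is not None

# Streaming decode: y is None, so the layer reads k/v straight from the cache.
k, v = cache.k_cache, cache.v_cache  # path taken when y is None
```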
@@ -352,17 +335,25 @@ def forward(
         # View + expand + reshape bring num_kv_heads to num_heads for k and v
         # to match q.

-        # [bsz, n_h, s, h_d]
-        q = q.transpose(1, 2)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
+        # k: [bsz, seq_len, n_kv, 1, h_d]
+        # v: [bsz, seq_len, n_kv, 1, h_d]
+        k = k.view(bsz, -1, self.num_kv_heads, 1, self.head_dim)
+        v = v.view(bsz, -1, self.num_kv_heads, 1, self.head_dim)

         # Expand the key and value tensors to have the same shape
         # as the query tensor by copying values across the relevant dim
         if self.num_heads != self.num_kv_heads:
-            expand_shape = (-1, -1, self.q_per_kv, -1, -1)
-            k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
-            v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
+            k = k.expand(bsz, -1, self.num_kv_heads, self.q_per_kv, self.head_dim)
+            v = v.expand(bsz, -1, self.num_kv_heads, self.q_per_kv, self.head_dim)
+
+        # [bsz, s, n_h, h_d]
+        k = k.reshape(bsz, -1, self.num_heads, self.head_dim)
+        v = v.reshape(bsz, -1, self.num_heads, self.head_dim)
+
+        # [bsz, n_h, s, h_d]
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)

         output = self._attention_fn(
             q,
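As a quick standalone check (illustrative only, not part of the change), the new view -> expand -> reshape sequence in `_sdpa` reproduces the usual "repeat each KV head q_per_kv times" behaviour of grouped-query attention. The shapes below are made up, and k starts already split into [bsz, seq_len, num_kv_heads, head_dim] for brevity rather than the flat layout used in the layer.

```python
import torch

bsz, seq_len, num_heads, num_kv_heads, head_dim = 2, 5, 8, 2, 16
q_per_kv = num_heads // num_kv_heads

k = torch.randn(bsz, seq_len, num_kv_heads, head_dim)

# New path from the diff: insert a broadcast dim, expand, then merge heads.
k_new = k.view(bsz, -1, num_kv_heads, 1, head_dim)
k_new = k_new.expand(bsz, -1, num_kv_heads, q_per_kv, head_dim)
k_new = k_new.reshape(bsz, -1, num_heads, head_dim)

# Reference: repeat each KV head q_per_kv times along the head dimension.
k_ref = k.repeat_interleave(q_per_kv, dim=2)

assert k_new.shape == (bsz, seq_len, num_heads, head_dim)
assert torch.equal(k_new, k_ref)
```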