@@ -352,26 +352,18 @@ def forward(
352352 # View + expand + reshape bring num_kv_heads to num_heads for k and v
353353 # to match q.
354354
355- # k: [bsz, seq_len, n_kv, 1, h_d]
356- # v: [bsz, seq_len, n_kv, 1, h_d]
357- k = k .view (bsz , - 1 , self .num_kv_heads , 1 , self .head_dim )
358- v = v .view (bsz , - 1 , self .num_kv_heads , 1 , self .head_dim )
359-
360- # Expand the key and value tensors to have the same shape
361- # as the query tensor by copying values across the relevant dim
362- if self .num_heads != self .num_kv_heads :
363- k = k .expand (bsz , - 1 , self .num_kv_heads , self .q_per_kv , self .head_dim )
364- v = v .expand (bsz , - 1 , self .num_kv_heads , self .q_per_kv , self .head_dim )
365-
366- # [bsz, s, n_h, h_d]
367- k = k .reshape (bsz , - 1 , self .num_heads , self .head_dim )
368- v = v .reshape (bsz , - 1 , self .num_heads , self .head_dim )
369-
370355 # [bsz, n_h, s, h_d]
371356 q = q .transpose (1 , 2 )
372357 k = k .transpose (1 , 2 )
373358 v = v .transpose (1 , 2 )
374359
360+ # Expand the key and value tensors to have the same shape
361+ # as the query tensor by copying values across the relevant dim
362+ if self .num_heads != self .num_kv_heads :
363+ expand_shape = (- 1 , - 1 , self .q_per_kv , - 1 , - 1 )
364+ k = k .unsqueeze (2 ).expand (expand_shape ).flatten (1 , 2 )
365+ v = v .unsqueeze (2 ).expand (expand_shape ).flatten (1 , 2 )
366+
375367 output = self ._attention_fn (
376368 q ,
377369 k ,
0 commit comments