@@ -1073,10 +1073,13 @@ def rotate_half(x):
     x2 = x[..., x.shape[-1] // 2 :]
     return torch.cat((-x2, x1), dim=-1)
 
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors."""
-    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
-    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+    if position_ids is not None:
+        cos = cos[position_ids]
+        sin = sin[position_ids]
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
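With the updated signature, callers that already hold cos/sin gathered for the current positions can skip position_ids entirely. Below is a minimal sketch of the new call path; the tensor shapes are illustrative assumptions, and it relies on the rotate_half/apply_rotary_pos_emb definitions shown above being in scope.

import torch

batch, num_heads, seq_len, head_dim = 2, 4, 8, 16
q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_heads, seq_len, head_dim)
cos = torch.randn(batch, seq_len, head_dim)  # already gathered per position
sin = torch.randn(batch, seq_len, head_dim)

# New-style call: cos/sin are pre-indexed, so position_ids can stay None;
# unsqueeze_dim=1 broadcasts them across the head dimension.
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)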
@@ -1113,17 +1116,24 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
         value_states = value_states.transpose(1, 2)
 
         kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            kv_seq_len += past_key_value[0].shape[-2]
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-
-        if past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+        is_legacy = not hasattr(self, "layer_idx")
+
+        if is_legacy:
+            if past_key_value is not None:
+                kv_seq_len += past_key_value[0].shape[-2]
+            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+            if past_key_value is not None:
+                key_states = torch.cat([past_key_value[0], key_states], dim=2)
+                value_states = torch.cat([past_key_value[1], value_states], dim=2)
+            past_key_value = (key_states, value_states) if use_cache else None
+        else:
+            cos, sin = self.rotary_emb(value_states, position_ids)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
-        past_key_value = (key_states, value_states) if use_cache else None
+            if past_key_value is not None:
+                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": kwargs.get("cache_position")}
+                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
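On the non-legacy path, past_key_value is expected to be a Cache object whose per-layer update() replaces the manual torch.cat over a (key, value) tuple. Below is a minimal sketch with DynamicCache, assuming a transformers release that ships the Cache API; the tensor shapes are illustrative.

import torch
from transformers import DynamicCache  # Cache API in recent transformers releases

cache = DynamicCache()
key_states = torch.randn(1, 4, 8, 16)    # (batch, num_kv_heads, seq_len, head_dim)
value_states = torch.randn(1, 4, 8, 16)

# update() appends the new states for this layer and returns the full key/value
# tensors accumulated so far. Extra cache_kwargs (e.g. sin/cos/cache_position)
# may be passed; DynamicCache itself does not need them.
key_states, value_states = cache.update(key_states, value_states, layer_idx=0)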