Commit efa68b4

Fix model patching for internlm2 (#814)

1 parent 101c7c2
1 file changed, 23 additions and 13 deletions

optimum/exporters/openvino/model_patcher.py
```diff
@@ -1073,10 +1073,13 @@ def rotate_half(x):
     x2 = x[..., x.shape[-1] // 2 :]
     return torch.cat((-x2, x1), dim=-1)
 
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors."""
-    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
-    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+    if position_ids is not None:
+        cos = cos[position_ids]
+        sin = sin[position_ids]
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
```
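
Making `position_ids` optional lets the patched attention pass cos/sin that have already been gathered per position (the newer rotary-embedding convention) while still accepting the legacy table-plus-indices call. Below is a minimal sketch, not part of the commit, that exercises both call patterns; the toy shapes and the random rotary tables are assumptions for illustration only.

```python
# Sketch: the patched helper accepts either full cos/sin tables + position_ids
# (legacy path) or cos/sin already gathered per position (new path).
import torch


def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors."""
    if position_ids is not None:
        # legacy path: cos/sin are (max_seq, head_dim) tables indexed by position
        cos = cos[position_ids]
        sin = sin[position_ids]
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


batch, heads, seq, head_dim = 1, 2, 4, 8
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)

# Legacy call: full tables plus explicit position_ids.
cos_table = torch.randn(16, head_dim)
sin_table = torch.randn(16, head_dim)
position_ids = torch.arange(seq).unsqueeze(0)
q1, k1 = apply_rotary_pos_emb(q, k, cos_table, sin_table, position_ids)

# New call: cos/sin already gathered per position, shape (batch, seq, head_dim).
q2, k2 = apply_rotary_pos_emb(q, k, cos_table[position_ids], sin_table[position_ids])

assert torch.allclose(q1, q2) and torch.allclose(k1, k2)
```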
```diff
@@ -1113,17 +1116,24 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     value_states = value_states.transpose(1, 2)
 
     kv_seq_len = key_states.shape[-2]
-    if past_key_value is not None:
-        kv_seq_len += past_key_value[0].shape[-2]
-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-
-    if past_key_value is not None:
-        # reuse k, v, self_attention
-        key_states = torch.cat([past_key_value[0], key_states], dim=2)
-        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+    is_legacy = not hasattr(self, "layer_idx")
+
+    if is_legacy:
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        if past_key_value is not None:
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+        past_key_value = (key_states, value_states) if use_cache else None
+    else:
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
-    past_key_value = (key_states, value_states) if use_cache else None
+        if past_key_value is not None:
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": kwargs.get("cache_position")}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
```
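
The patched forward now branches on `hasattr(self, "layer_idx")`: InternLM2 modeling code written for older transformers keeps the legacy `(key, value)` tuple cache and `torch.cat` reuse, while newer code that carries a `layer_idx` goes through the transformers `Cache` API via `past_key_value.update(...)`. The sketch below contrasts the two conventions; it is not part of the commit, the shapes are toy assumptions, and `DynamicCache` is the standard transformers `Cache` implementation rather than something this patch introduces.

```python
# Sketch: legacy tuple KV cache vs. the transformers Cache-object convention.
import torch
from transformers.cache_utils import DynamicCache

batch, kv_heads, head_dim = 1, 2, 8
k_prev = torch.randn(batch, kv_heads, 3, head_dim)
v_prev = torch.randn(batch, kv_heads, 3, head_dim)
k_new = torch.randn(batch, kv_heads, 1, head_dim)
v_new = torch.randn(batch, kv_heads, 1, head_dim)

# Legacy convention: past_key_value is a (key, value) tuple, extended with torch.cat.
legacy_past = (k_prev, v_prev)
k_cat = torch.cat([legacy_past[0], k_new], dim=2)
v_cat = torch.cat([legacy_past[1], v_new], dim=2)

# Cache-object convention: update() appends and returns the full cached tensors.
cache = DynamicCache()
cache.update(k_prev, v_prev, layer_idx=0)                # prefill step
k_upd, v_upd = cache.update(k_new, v_new, layer_idx=0)   # decode step

assert torch.allclose(k_cat, k_upd) and torch.allclose(v_cat, v_upd)
```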
