Commit 16c27ca

support minicpm3 (#1029)
* support minicpm3
* model patcher
* add test
* update readme
* Update tests/openvino/test_modeling.py
* Update tests/openvino/test_modeling.py
* Update optimum/exporters/openvino/model_patcher.py
1 parent 860bdf8 · commit 16c27ca

File tree

5 files changed: +200 -0 lines changed

docs/source/openvino/models.mdx

Lines changed: 1 addition & 0 deletions
@@ -70,6 +70,7 @@ Here is the list of the supported architectures :
 - MT5
 - Marian
 - MiniCPM
+- MiniCPM3
 - Mistral
 - Mixtral
 - MobileBert

optimum/exporters/openvino/model_configs.py

Lines changed: 55 additions & 0 deletions
@@ -79,6 +79,7 @@
     LlamaModelPatcher,
     LlavaImageEmbeddingModelPatcher,
     LlavaQwen2ImageEmbeddingsModelPatcher,
+    MiniCPM3Patcher,
     MiniCPMVImageEmbeddingsModelPatcher,
     MiniCPMVResamplerModelPatcher,
     MistralModelPatcher,
@@ -192,6 +193,60 @@ class MiniCPMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


+class OVMiniCPM3DummyPastKeyValuesGenerator(MistralDummyPastKeyValuesGenerator):
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedTextConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
+        random_batch_size_range: Optional[Tuple[int, int]] = None,
+        random_sequence_length_range: Optional[Tuple[int, int]] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            task=task,
+            normalized_config=normalized_config,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            random_batch_size_range=random_batch_size_range,
+            random_sequence_length_range=random_sequence_length_range,
+            **kwargs,
+        )
+        self.v_head_dim = getattr(normalized_config, "v_head_dim", self.hidden_size // self.num_attention_heads)
+        self.k_head_dim = normalized_config.qk_nope_head_dim + normalized_config.qk_rope_head_dim
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        v_shape = (
+            self.batch_size,
+            self.num_key_value_heads,
+            self.sequence_length,
+            self.v_head_dim,
+        )
+        k_shape = (self.batch_size, self.num_key_value_heads, self.sequence_length, self.k_head_dim)
+        return [
+            (
+                self.random_float_tensor(k_shape, framework=framework, dtype=float_dtype),
+                self.random_float_tensor(v_shape, framework=framework, dtype=float_dtype),
+            )
+            for _ in range(self.num_layers)
+        ]
+
+
+@register_in_tasks_manager("minicpm3", *["text-generation", "text-generation-with-past"], library_name="transformers")
+class MiniCPM3OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
+    DEFAULT_ONNX_OPSET = 14
+
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, OVMiniCPM3DummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = OVMiniCPM3DummyPastKeyValuesGenerator
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> ModelPatcher:
+        return MiniCPM3Patcher(self, model, model_kwargs=model_kwargs)
+
+
 @register_in_tasks_manager("stablelm", *["text-generation", "text-generation-with-past"], library_name="transformers")
 class StableLMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     DEFAULT_ONNX_OPSET = 14
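
The dedicated dummy generator is needed because MiniCPM3 caches keys and values with different head dimensions: each key head concatenates the non-rotary (qk_nope_head_dim) and rotary (qk_rope_head_dim) parts, while each value head uses v_head_dim. A minimal sketch of the resulting per-layer cache shapes, with illustrative dimensions rather than values taken from a real MiniCPM3 config:

    # Illustrative dimensions only, not from an actual MiniCPM3 config.
    batch_size, num_key_value_heads, sequence_length = 1, 8, 16
    qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 64, 32, 64

    # Keys carry both the non-rotary and rotary parts; values do not.
    k_head_dim = qk_nope_head_dim + qk_rope_head_dim  # 96
    k_shape = (batch_size, num_key_value_heads, sequence_length, k_head_dim)
    v_shape = (batch_size, num_key_value_heads, sequence_length, v_head_dim)

    print(k_shape)  # (1, 8, 16, 96)
    print(v_shape)  # (1, 8, 16, 64)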

optimum/exporters/openvino/model_patcher.py

Lines changed: 141 additions & 0 deletions
@@ -3237,3 +3237,144 @@ def __init__(
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         self._model.forward = self._model.__orig_forward
+
+
+def minicpm3_attn_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value=None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    def rotate_half(x):
+        """Rotates half the hidden dims of the input."""
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)
+
+    def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+        """Applies Rotary Position Embedding to the query and key tensors.
+        Args:
+            q (`torch.Tensor`): The query tensor.
+            k (`torch.Tensor`): The key tensor.
+            cos (`torch.Tensor`): The cosine part of the rotary embedding.
+            sin (`torch.Tensor`): The sine part of the rotary embedding.
+            position_ids (`torch.Tensor`):
+                The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+                used to pass offsetted position ids when working with a KV-cache.
+            unsqueeze_dim (`int`, *optional*, defaults to 1):
+                The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+                sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+                that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+                k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+                cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+                the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+        Returns:
+            `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+        """
+        orig_dtype = k.dtype
+        cos = cos[position_ids].unsqueeze(unsqueeze_dim)  # [bs, 1, seq_len, dim]
+        sin = sin[position_ids].unsqueeze(unsqueeze_dim)  # [bs, 1, seq_len, dim]
+        q_fp32 = q.to(dtype=torch.float32, device=q.device)
+        k_fp32 = k.to(dtype=torch.float32, device=k.device)
+        q_embed = (q_fp32 * cos) + (rotate_half(q_fp32) * sin)
+        k_embed = (k_fp32 * cos) + (rotate_half(k_fp32) * sin)
+        return q_embed.to(dtype=orig_dtype), k_embed.to(dtype=orig_dtype)
+
+    if output_attentions:
+        return self._orig_forward(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+
+    bsz, q_len, _ = hidden_states.shape
+
+    q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
+    q = q.view(hidden_states.shape[0], hidden_states.shape[1], self.num_heads, self.q_head_dim).transpose(1, 2)
+    q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+
+    compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+    compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+    k_pe = k_pe.view(hidden_states.shape[0], hidden_states.shape[1], 1, self.qk_rope_head_dim).transpose(1, 2)
+    kv = (
+        self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
+        .view(hidden_states.shape[0], hidden_states.shape[1], self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+        .transpose(1, 2)
+    )
+
+    k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+
+    kv_seq_len = value_states.shape[-2]
+    if past_key_value is not None:
+        if self.layer_idx is None:
+            raise ValueError(
+                f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                "with a layer index."
+            )
+        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+
+    q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
+
+    # Difference with original code, k_pe.new_empty create constant tensor in torchscript
+    query_states = torch.concat([q_nope, q_pe], dim=-1)
+    # query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+    # query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
+    # query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
+    key_states = torch.concat([k_nope, k_pe.expand(-1, self.num_heads, -1, -1)], dim=-1)
+    # key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+    # key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
+    # key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
+    if past_key_value is not None:
+        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+    if attention_mask is not None:
+        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+            )
+
+    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+    # Reference: https://github.com/pytorch/pytorch/issues/112577.
+    if query_states.device.type == "cuda" and attention_mask is not None:
+        query_states = query_states.contiguous()
+        key_states = key_states.contiguous()
+        value_states = value_states.contiguous()
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query_states,
+        key_states,
+        value_states,
+        attn_mask=attention_mask,
+        dropout_p=self.attention_dropout if self.training else 0.0,
+        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+        is_causal=self.is_causal and attention_mask is None and q_len > 1,
+    )
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.reshape(hidden_states.shape[0], hidden_states.shape[1], self.hidden_size)
+
+    attn_output = self.o_proj(attn_output)
+
+    return attn_output, None, past_key_value
+
+
+class MiniCPM3Patcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        for block in self._model.model.layers:
+            block.self_attn._orig_forward = block.self_attn.forward
+            block.self_attn.forward = types.MethodType(minicpm3_attn_forward, block.self_attn)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        for block in self._model.model.layers:
+            block.self_attn.forward = block.self_attn._orig_forward
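
The only functional deviation from the upstream MiniCPM3 attention is building query_states and key_states with torch.concat instead of new_empty plus slice assignment, since (per the in-code comment) new_empty produces a constant tensor when the model is traced for export. A minimal sketch of the equivalence, using illustrative dimensions that are not taken from any real MiniCPM3 config:

    import torch

    # Illustrative dimensions only.
    bsz, num_heads, q_len = 2, 4, 8
    qk_nope_head_dim, qk_rope_head_dim = 16, 8

    q_nope = torch.randn(bsz, num_heads, q_len, qk_nope_head_dim)
    q_pe = torch.randn(bsz, num_heads, q_len, qk_rope_head_dim)

    # Export-friendly path used by the patcher: concatenate the two parts,
    # so the resulting shape stays tied to the inputs under tracing.
    query_states = torch.concat([q_nope, q_pe], dim=-1)

    # Upstream path: pre-allocate with new_empty and write slices; when traced,
    # the new_empty shape is recorded as a constant.
    query_states_ref = q_pe.new_empty(bsz, num_heads, q_len, qk_nope_head_dim + qk_rope_head_dim)
    query_states_ref[:, :, :, :qk_nope_head_dim] = q_nope
    query_states_ref[:, :, :, qk_nope_head_dim:] = q_pe

    # Both constructions produce the same tensor values.
    assert torch.equal(query_states, query_states_ref)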

tests/openvino/test_modeling.py

Lines changed: 2 additions & 0 deletions
@@ -914,6 +914,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "arctic",
         "exaone",
         "mistral-nemo",
+        "minicpm3",
     )

     GENERATION_LENGTH = 100
@@ -935,6 +936,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "glm4",
         "exaone",
         "decilm",
+        "minicpm3",
     )

     @parameterized.expand(SUPPORTED_ARCHITECTURES)

tests/openvino/utils_tests.py

Lines changed: 1 addition & 0 deletions
@@ -86,6 +86,7 @@
     "marian": "sshleifer/tiny-marian-en-de",
     "mbart": "hf-internal-testing/tiny-random-mbart",
     "minicpm": "katuni4ka/tiny-random-minicpm",
+    "minicpm3": "katuni4ka/tiny-random-minicpm3",
     "minicpmv": "katuni4ka/tiny-random-minicpmv-2_6",
     "mistral": "echarlaix/tiny-random-mistral",
    "mistral-nemo": "katuni4ka/tiny-random-mistral-nemo",
