117 | 117 | LlavaNextVideoImageEmbeddingModelPatcher, |
118 | 118 | LlavaQwen2ImageEmbeddingsModelPatcher, |
119 | 119 | MairaImageEmbeddingModelPatcher, |
| 120 | + MambaPatcher, |
120 | 121 | MarianModelPatcher, |
121 | 122 | MarianStatefulSeq2SeqDecoderPatcher, |
122 | 123 | MiniCPM3Patcher, |
@@ -4382,3 +4383,128 @@ def patch_model_for_export( |
4382 | 4383 | if self._behavior != VLMConfigBehavior.VISION_EMBEDDINGS: |
4383 | 4384 | return super().patch_model_for_export(model, model_kwargs) |
4384 | 4385 | return Llama4ImageEmbeddingsModelPatcher(self, model, model_kwargs) |
| 4386 | + |
| 4387 | + |
| 4388 | +class MambaCacheDummyInputGenerator(DummyInputGenerator): |
| 4389 | + """ |
| 4390 | + Generates dummy cache_params (per-layer ssm_states and conv_states) and cache_position inputs for Mamba architectures. |
| 4391 | + """ |
| 4392 | + |
| 4393 | + SUPPORTED_INPUT_NAMES = ("cache_params", "cache_position") |
| 4394 | + |
| 4395 | + def __init__( |
| 4396 | + self, |
| 4397 | + task: str, |
| 4398 | + normalized_config, |
| 4399 | + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], |
| 4400 | + sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"], |
| 4401 | + **kwargs, |
| 4402 | + ): |
| 4403 | + self.normalized_config = normalized_config |
| 4404 | + self.batch_size = batch_size |
| 4405 | + self.sequence_length = sequence_length |
| 4406 | + self.intermediate_size = self.normalized_config.config.intermediate_size |
| 4407 | + self.ssm_state_size = self.normalized_config.config.state_size |
| 4408 | + self.conv_kernel_size = self.normalized_config.config.conv_kernel |
| 4409 | + |
| 4410 | + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): |
| 4411 | + if input_name == "cache_params": |
| 4412 | + ssm_shape = [self.batch_size, self.intermediate_size, self.ssm_state_size] |
| 4413 | + conv_shape = [self.batch_size, self.intermediate_size, self.conv_kernel_size] |
| 4414 | + return [ |
| 4415 | + ( |
| 4416 | + self.random_float_tensor(ssm_shape, framework=framework, dtype=float_dtype), |
| 4417 | + self.random_float_tensor(conv_shape, framework=framework, dtype=float_dtype), |
| 4418 | + ) |
| 4419 | + for _ in range(self.normalized_config.num_layers) |
| 4420 | + ] |
| 4421 | + elif input_name == "cache_position": |
| 4422 | + return self.random_int_tensor( |
| 4423 | + shape=[self.conv_kernel_size], |
| 4424 | + max_value=self.sequence_length, |
| 4425 | + framework=framework, |
| 4426 | + dtype=int_dtype, |
| 4427 | + ) |
| 4428 | + |
| 4429 | + raise ValueError(f"Unsupported input name {input_name}") |
| 4430 | + |
| 4431 | + |
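For reference, a minimal sketch (not part of the diff) of what `MambaCacheDummyInputGenerator` produces for `cache_params` and `cache_position`; the config sizes below are illustrative assumptions, not values taken from any real checkpoint.

```python
# Standalone illustration of the dummy cache layout built above; all sizes are made up.
import torch

batch_size = 2
num_layers = 4              # normalized_config.num_layers
intermediate_size = 1536    # config.intermediate_size
ssm_state_size = 16         # config.state_size
conv_kernel_size = 4        # config.conv_kernel
sequence_length = 8         # dummy sequence length

# One (ssm_state, conv_state) pair per decoder layer, mirroring generate("cache_params").
cache_params = [
    (
        torch.rand(batch_size, intermediate_size, ssm_state_size),    # ssm state
        torch.rand(batch_size, intermediate_size, conv_kernel_size),  # conv state
    )
    for _ in range(num_layers)
]

# cache_position spans the convolution window, mirroring generate("cache_position").
cache_position = torch.randint(0, sequence_length, (conv_kernel_size,))
```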
| 4432 | +@register_in_tasks_manager( |
| 4433 | + "falcon_mamba", *["text-generation", "text-generation-with-past"], library_name="transformers" |
| 4434 | +) |
| 4435 | +@register_in_tasks_manager("mamba", *["text-generation", "text-generation-with-past"], library_name="transformers") |
| 4436 | +class MambaOpenVINOConfig(TextDecoderOnnxConfig): |
| 4437 | + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MambaCacheDummyInputGenerator) |
| 4438 | + DUMMY_PKV_GENERATOR_CLASS = MambaCacheDummyInputGenerator |
| 4439 | + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig |
| 4440 | + MIN_TRANSFORMERS_VERSION = version.parse("4.43.0") |
| 4441 | + |
| 4442 | + @property |
| 4443 | + def inputs(self) -> Dict[str, Dict[int, str]]: |
| 4444 | + common_inputs = { |
| 4445 | + "input_ids": {0: "batch_size", 1: "sequence_length"}, |
| 4446 | + "attention_mask": {0: "batch_size", 1: "sequence_length"}, |
| 4447 | + "cache_position": {0: "cache_sequence_length"}, |
| 4448 | + } |
| 4449 | + if self.use_past_in_inputs: |
| 4450 | + self.add_past_key_values(common_inputs, direction="inputs") |
| 4451 | + return common_inputs |
| 4452 | + |
| 4453 | + def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str): |
| 4454 | + """ |
| 4455 | + Fills `inputs_or_outputs` mapping with cache_params (ssm and conv states) dynamic axes considering the direction. |
| 4456 | +
| 4457 | + Args: |
| 4458 | + inputs_or_outputs (`Dict[str, Dict[int, str]]`): |
| 4459 | + The mapping to fill. |
| 4460 | + direction (`str`): |
| 4461 | + either "inputs" or "outputs"; it specifies whether `inputs_or_outputs` is the input mapping or the |
| 4462 | + output mapping, which is important for axes naming. |
| 4463 | + """ |
| 4464 | + if direction not in ["inputs", "outputs"]: |
| 4465 | + raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') |
| 4466 | + |
| 4467 | + if direction == "inputs": |
| 4468 | + ssm_conv_states_name = "cache_params.past" |
| 4469 | + else: |
| 4470 | + ssm_conv_states_name = "cache_params.present" |
| 4471 | + |
| 4472 | + for i in range(self._normalized_config.num_layers): |
| 4473 | + # ssm_state: [batch_size, intermediate_size, ssm_state_size] |
| 4474 | + inputs_or_outputs[f"{ssm_conv_states_name}.ssm.{i}"] = {0: "batch_size"} |
| 4475 | + # conv_state: [batch_size, intermediate_size, conv_kernel_size] |
| 4476 | + inputs_or_outputs[f"{ssm_conv_states_name}.conv.{i}"] = {0: "batch_size"} |
| 4477 | + |
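To make the axis naming concrete, this is the mapping the `inputs` property would yield for a hypothetical 2-layer model with `use_past_in_inputs=True` (the layer count is an assumption for illustration):

```python
# Expected input mapping for an assumed 2-layer Mamba config; only batch_size is dynamic
# for the cache entries, since ssm/conv state shapes are fixed by the model config.
expected_inputs = {
    "input_ids": {0: "batch_size", 1: "sequence_length"},
    "attention_mask": {0: "batch_size", 1: "sequence_length"},
    "cache_position": {0: "cache_sequence_length"},
    "cache_params.past.ssm.0": {0: "batch_size"},
    "cache_params.past.conv.0": {0: "batch_size"},
    "cache_params.past.ssm.1": {0: "batch_size"},
    "cache_params.past.conv.1": {0: "batch_size"},
}
```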
| 4478 | + def patch_model_for_export( |
| 4479 | + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None |
| 4480 | + ): |
| 4481 | + return MambaPatcher(self, model, model_kwargs) |
| 4482 | + |
| 4483 | + def generate_dummy_inputs(self, framework: str = "pt", **kwargs): |
| 4484 | + # Override `generate_dummy_inputs` since the mamba model carries ssm_states and conv_states |
| 4485 | + # rather than past_key_values; they are generated together under the single `cache_params` input |
| 4486 | + dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs) |
| 4487 | + |
| 4488 | + dummy_inputs = {} |
| 4489 | + input_names = [key for key in self.inputs.keys() if not key.startswith("cache_params")] |
| 4490 | + if self.use_past_in_inputs: |
| 4491 | + input_names.append("cache_params") |
| 4492 | + |
| 4493 | + for input_name in input_names: |
| 4494 | + input_was_inserted = False |
| 4495 | + for dummy_input_gen in dummy_inputs_generators: |
| 4496 | + if dummy_input_gen.supports_input(input_name): |
| 4497 | + dummy_inputs[input_name] = self.overwrite_shape_and_generate_input( |
| 4498 | + dummy_input_gen, |
| 4499 | + input_name, |
| 4500 | + framework, |
| 4501 | + input_shapes=kwargs, |
| 4502 | + ) |
| 4503 | + input_was_inserted = True |
| 4504 | + break |
| 4505 | + if not input_was_inserted: |
| 4506 | + raise RuntimeError( |
| 4507 | + f'Could not generate dummy input for "{input_name}". Try adding a proper dummy input generator to the model ONNX config.' |
| 4508 | + ) |
| 4509 | + |
| 4510 | + return dummy_inputs |
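End-to-end, the new config is picked up automatically through the `register_in_tasks_manager` decorators. Below is a sketch of typical usage once this change is installed; the checkpoint id and generation settings are assumptions for illustration, not part of the PR:

```python
# Export a Mamba checkpoint to OpenVINO and run generation with it.
from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "state-spaces/mamba-130m-hf"  # example checkpoint, swap in your own
tokenizer = AutoTokenizer.from_pretrained(model_id)

# export=True routes through the OpenVINO export, which resolves MambaOpenVINOConfig
# for the text-generation-with-past task registered above.
model = OVModelForCausalLM.from_pretrained(model_id, export=True)

inputs = tokenizer("State space models", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```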