@@ -95,6 +95,7 @@
     LlamaModelPatcher,
     LlavaImageEmbeddingModelPatcher,
     LlavaQwen2ImageEmbeddingsModelPatcher,
+    MambaPatcher,
     MiniCPM3Patcher,
     MiniCPMModelPatcher,
     MiniCPMVImageEmbeddingsModelPatcher,
@@ -2880,3 +2881,132 @@ def patch_model_for_export(
         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
     ) -> "ModelPatcher":
         return DeepseekPatcher(self, model, model_kwargs=model_kwargs)
+
+
+class MambaCacheDummyInputGenerator(DummyInputGenerator):
+    """
+    Generates dummy cache inputs (SSM states, convolution states and cache position) for Mamba architectures.
+    """
+
+    SUPPORTED_INPUT_NAMES = ("past_ssm_states", "past_conv_states", "cache_position")
+
+    def __init__(
+        self,
+        task: str,
+        normalized_config,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
+        **kwargs,
+    ):
+        self.normalized_config = normalized_config
+        self.batch_size = batch_size
+        self.sequence_length = sequence_length
+        self.intermediate_size = self.normalized_config.config.intermediate_size
+        self.ssm_state_size = self.normalized_config.config.state_size
+        self.conv_kernel_size = self.normalized_config.config.conv_kernel
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        if input_name == "past_ssm_states":
+            # one SSM state tensor of shape (batch_size, intermediate_size, ssm_state_size) per layer
+            ssm_shape = [self.batch_size, self.intermediate_size, self.ssm_state_size]
+            return [
+                self.random_float_tensor(ssm_shape, framework=framework, dtype=float_dtype)
+                for _ in range(self.normalized_config.num_layers)
+            ]
+
+        elif input_name == "past_conv_states":
+            # one convolution state tensor of shape (batch_size, intermediate_size, conv_kernel_size) per layer
+            conv_shape = [self.batch_size, self.intermediate_size, self.conv_kernel_size]
+            return [
+                self.random_float_tensor(conv_shape, framework=framework, dtype=float_dtype)
+                for _ in range(self.normalized_config.num_layers)
+            ]
+
+        elif input_name == "cache_position":
+            # cache position indices; the dummy tensor length matches the convolution kernel size
+            return self.random_int_tensor(
+                shape=[self.conv_kernel_size],
+                max_value=self.sequence_length,
+                framework=framework,
+                dtype=int_dtype,
+            )
+
+        raise ValueError(f"Unsupported input name {input_name}")
+
+
+@register_in_tasks_manager("mamba", *["text-generation", "text-generation-with-past"], library_name="transformers")
+class MambaOpenVINOConfig(TextDecoderOnnxConfig):
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MambaCacheDummyInputGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = MambaCacheDummyInputGenerator
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        if self.use_past_in_inputs:
+            common_inputs = {"input_ids": {0: "batch_size", 1: "sequence_length"}}
+            self.add_past_key_values(common_inputs, direction="inputs")
+            # common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + 1"}
+            common_inputs["cache_position"] = {0: "cache_sequence_length"}
+        else:
+            common_inputs = {
+                "input_ids": {0: "batch_size", 1: "sequence_length"},
+                # "attention_mask": {0: "batch_size", 1: "sequence_length"},
+                "cache_position": {0: "cache_sequence_length"},
+            }
+        return common_inputs
+
+    def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
+        """
+        Fills the `inputs_or_outputs` mapping with dynamic axes for the Mamba cache (SSM and convolution states),
+        depending on the direction.
+
+        Args:
+            inputs_or_outputs (`Dict[str, Dict[int, str]]`):
+                The mapping to fill.
+            direction (`str`):
+                Either "inputs" or "outputs"; it specifies whether `inputs_or_outputs` is the input mapping or the
+                output mapping, which matters for axes naming.
+        """
+        if direction not in ["inputs", "outputs"]:
+            raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')
+
+        if direction == "inputs":
+            ssm_name = "past_ssm_states"
+            conv_name = "past_conv_states"
+        else:
+            ssm_name = "present_ssm_states"
+            conv_name = "present_conv_states"
+
+        for i in range(self._normalized_config.num_layers):
+            inputs_or_outputs[f"{ssm_name}.{i}"] = {0: "batch_size"}
+
+        for i in range(self._normalized_config.num_layers):
+            inputs_or_outputs[f"{conv_name}.{i}"] = {0: "batch_size"}
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return MambaPatcher(self, model, model_kwargs=model_kwargs)
+
+    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
+        dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs)
+
+        dummy_inputs = {}
+        input_names = [key for key in self.inputs.keys() if not key.startswith("past_")]
+        if self.use_past_in_inputs and self.use_cache_branch is not False:
+            input_names.extend(["past_ssm_states", "past_conv_states"])
+
+        for input_name in input_names:
+            input_was_inserted = False
+            for dummy_input_gen in dummy_inputs_generators:
+                if dummy_input_gen.supports_input(input_name):
+                    dummy_inputs[input_name] = self.overwrite_shape_and_generate_input(
+                        dummy_input_gen,
+                        input_name,
+                        framework,
+                        input_shapes=kwargs,
+                    )
+                    input_was_inserted = True
+                    break
+            if not input_was_inserted:
+                raise RuntimeError(
+                    f'Could not generate dummy input for "{input_name}". Try adding a proper dummy input generator to the model ONNX config.'
+                )
+
+        return dummy_inputs
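
For reference, a minimal sketch of how the config added above could be driven to produce dummy export inputs. The import path, the checkpoint name and the constructor keywords below are assumptions based on the usual optimum layout; they are illustrative, not part of this patch.

from transformers import AutoConfig

# assumed module path for the class registered above
from optimum.exporters.openvino.model_configs import MambaOpenVINOConfig

# hypothetical Mamba checkpoint; any transformers Mamba config should work
config = AutoConfig.from_pretrained("state-spaces/mamba-130m-hf")

export_config = MambaOpenVINOConfig(
    config,
    task="text-generation-with-past",
    use_past=True,
    use_past_in_inputs=True,
)

dummy_inputs = export_config.generate_dummy_inputs(framework="pt", batch_size=2, sequence_length=8)

# per the overridden generate_dummy_inputs, expect input_ids, cache_position,
# past_ssm_states and past_conv_states
print(sorted(dummy_inputs))

# past_ssm_states is a list with one (batch_size, intermediate_size, state_size) tensor per layer
print(len(dummy_inputs["past_ssm_states"]), dummy_inputs["past_ssm_states"][0].shape)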