import enum
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import inspect
import re

from packaging import version
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, TFPreTrainedModel
    LlamaModelPatcher,
    LlavaImageEmbeddingModelPatcher,
    LlavaQwen2ImageEmbeddingsModelPatcher,
    MambaPatcher,
    MiniCPM3Patcher,
    MiniCPMModelPatcher,
    MiniCPMVImageEmbeddingsModelPatcher,
@@ -2880,3 +2883,126 @@ def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        return DeepseekPatcher(self, model, model_kwargs=model_kwargs)


class MambaCacheDummyInputGenerator(DummyInputGenerator):
    """
    Generates dummy cache inputs (per-layer SSM and convolution states) for Mamba architectures.
    """

    SUPPORTED_INPUT_NAMES = ("past_ssm_states", "past_conv_states", "cache_position")

    def __init__(
        self,
        task: str,
        normalized_config,
        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
        **kwargs,
    ):
        self.normalized_config = normalized_config
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        # Mamba cache dimensions come from the model config
        self.intermediate_size = self.normalized_config.config.intermediate_size
        self.ssm_state_size = self.normalized_config.config.state_size
        self.conv_kernel_size = self.normalized_config.config.conv_kernel

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        if input_name == "past_ssm_states":
            # one SSM state tensor per layer: [batch_size, intermediate_size, ssm_state_size]
            ssm_shape = [self.batch_size, self.intermediate_size, self.ssm_state_size]
            return [
                self.random_float_tensor(ssm_shape, framework=framework, dtype=float_dtype)
                for _ in range(self.normalized_config.num_layers)
            ]

        elif input_name == "past_conv_states":
            # one convolution state tensor per layer: [batch_size, intermediate_size, conv_kernel_size]
            conv_shape = [self.batch_size, self.intermediate_size, self.conv_kernel_size]
            return [
                self.random_float_tensor(conv_shape, framework=framework, dtype=float_dtype)
                for _ in range(self.normalized_config.num_layers)
            ]

        elif input_name == "cache_position":
            return self.random_int_tensor(
                shape=[self.conv_kernel_size],
                max_value=self.sequence_length,
                framework=framework,
                dtype=int_dtype,
            )

        raise ValueError(f"Unsupported input name {input_name}")
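
# Illustrative note (example values only, not taken from this file): for a Mamba config with
# intermediate_size=1536, state_size=16, conv_kernel=4 and 24 layers, and batch_size=2, the
# generator above yields per layer
#   past_ssm_states[i]  -> float tensor of shape [2, 1536, 16]
#   past_conv_states[i] -> float tensor of shape [2, 1536, 4]
# plus a single int64 "cache_position" tensor of shape [4] with values below sequence_length.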


@register_in_tasks_manager(
    "mamba", *["text-generation", "text-generation-with-past"], library_name="transformers"
)
class MambaOpenVINOConfig(TextDecoderOnnxConfig):
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MambaCacheDummyInputGenerator)
    DUMMY_PKV_GENERATOR_CLASS = MambaCacheDummyInputGenerator
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        if self.use_past_in_inputs:
            common_inputs = {"input_ids": {0: "batch_size", 1: "sequence_length"}}
            self.add_past_key_values(common_inputs, direction="inputs")
            # common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + 1"}
            common_inputs["cache_position"] = {0: "cache_sequence_length"}
        else:
            common_inputs = {
                "input_ids": {0: "batch_size", 1: "sequence_length"},
                # "attention_mask": {0: "batch_size", 1: "sequence_length"},
                "cache_position": {0: "cache_sequence_length"},
            }
        return common_inputs

    def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
        """
        Fills `inputs_or_outputs` mapping with the dynamic axes of the past SSM and convolution states,
        considering the direction.

        Args:
            inputs_or_outputs (`Dict[str, Dict[int, str]]`):
                The mapping to fill.
            direction (`str`):
                either "inputs" or "outputs", it specifies whether `inputs_or_outputs` is the input mapping or the
                output mapping, this is important for axes naming.
        """
        if direction not in ["inputs", "outputs"]:
            raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

        if direction == "inputs":
            ssm_name = "past_ssm_states"
            conv_name = "past_conv_states"
        else:
            ssm_name = "present_ssm_states"
            conv_name = "present_conv_states"

        for i in range(self._normalized_config.num_layers):
            inputs_or_outputs[f"{ssm_name}.{i}"] = {0: "batch_size"}

        for i in range(self._normalized_config.num_layers):
            inputs_or_outputs[f"{conv_name}.{i}"] = {0: "batch_size"}
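
        # For a 2-layer model and direction="inputs", for example, this adds the entries
        # "past_ssm_states.0", "past_ssm_states.1", "past_conv_states.0" and "past_conv_states.1",
        # each with {0: "batch_size"} as its only dynamic axis ("present_*" names are used for outputs).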

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ):
        return MambaPatcher(self, model, model_kwargs)

    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
        dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs)

        dummy_inputs = {}
        # per-layer "past_*.{i}" entries are skipped here; the cache is generated as the two
        # aggregated list inputs "past_ssm_states" and "past_conv_states" below
        input_names = [key for key in self.inputs.keys() if not key.startswith("past_")]
        if self.use_past_in_inputs and self.use_cache_branch is not False:
            input_names.extend(["past_ssm_states", "past_conv_states"])

        for input_name in input_names:
            input_was_inserted = False
            for dummy_input_gen in dummy_inputs_generators:
                if dummy_input_gen.supports_input(input_name):
                    dummy_inputs[input_name] = self.overwrite_shape_and_generate_input(
                        dummy_input_gen,
                        input_name,
                        framework,
                        input_shapes=kwargs,
                    )
                    input_was_inserted = True
                    break
            if not input_was_inserted:
                raise RuntimeError(
                    f'Could not generate dummy input for "{input_name}". Try adding a proper dummy input generator '
                    "to the model ONNX config."
                )

        return dummy_inputs
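

# Rough usage sketch (illustrative only, kept as a comment; the constructor arguments shown
# here are inherited from TextDecoderOnnxConfig, and "state-spaces/mamba-130m-hf" is just an
# example checkpoint):
#
#   from transformers import AutoConfig
#
#   config = AutoConfig.from_pretrained("state-spaces/mamba-130m-hf")
#   export_config = MambaOpenVINOConfig(
#       config, task="text-generation-with-past", use_past=True, use_past_in_inputs=True
#   )
#   dummy_inputs = export_config.generate_dummy_inputs(framework="pt")
#   # expected keys: "input_ids", "cache_position", and the per-layer lists
#   # "past_ssm_states" / "past_conv_states"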