8989 Phi3ModelPatcher ,
9090 Phi3VisionImageEmbeddingsPatcher ,
9191 QwenModelPatcher ,
92+ Qwen2VLLanguageModelPatcher ,
93+ Qwen2VLVisionEmbMergerPatcher ,
9294 RotaryEmbPatcher ,
9395 UpdateCausalMaskModelPatcher ,
9496 XverseModelPatcher ,
@@ -106,9 +108,13 @@ def init_model_configs():
106108 "transformers" ,
107109 "LlavaNextForConditionalGeneration" ,
108110 )
109- TasksManager ._TRANSFORMERS_TASKS_TO_MODEL_LOADERS [
110- "image-text-to-text"
111- ] = TasksManager ._TRANSFORMERS_TASKS_TO_MODEL_LOADERS ["text-generation" ]
111+ TasksManager ._CUSTOM_CLASSES [("pt" , "qwen2-vl" , "image-text-to-text" )] = (
112+ "transformers" ,
113+ "Qwen2VLForConditionalGeneration" ,
114+ )
115+ TasksManager ._TRANSFORMERS_TASKS_TO_MODEL_LOADERS ["image-text-to-text" ] = (
116+ TasksManager ._TRANSFORMERS_TASKS_TO_MODEL_LOADERS ["text-generation" ]
117+ )
112118
113119 supported_model_types = [
114120 "_SUPPORTED_MODEL_TYPE" ,
@@ -1288,18 +1294,26 @@ def patch_model_for_export(
12881294
12891295
12901296class LMInputEmbedsConfigHelper (TextDecoderWithPositionIdsOnnxConfig ):
1291- def __init__ (self , export_config ):
1297+ def __init__ (self , export_config , patcher_cls = None , dummy_input_generator = None , inputs_update = None ):
12921298 self .orig_export_config = export_config
1299+ if dummy_input_generator is not None :
1300+ export_config .DUMMY_INPUT_GENERATOR_CLASSES = (
1301+ dummy_input_generator ,
1302+ ) + export_config .DUMMY_INPUT_GENERATOR_CLASSES
12931303 self .DUMMY_INPUT_GENERATOR_CLASSES = export_config .DUMMY_INPUT_GENERATOR_CLASSES
12941304 self .DEFAULT_ONNX_OPSET = export_config .DEFAULT_ONNX_OPSET
12951305 self .DUMMY_PKV_GENERATOR_CLASS = export_config .DUMMY_PKV_GENERATOR_CLASS
12961306 self ._config = export_config ._config
12971307 self ._normalized_config = export_config ._normalized_config
12981308 self .use_past = export_config .use_past
1309+ self .patcher_cls = patcher_cls
1310+ self .input_info_upd = inputs_update
12991311
13001312 def patch_model_for_export (
13011313 self , model : Union ["PreTrainedModel" , "TFPreTrainedModel" ], model_kwargs : Optional [Dict [str , Any ]] = None
13021314 ) -> "ModelPatcher" :
1315+ if self .patcher_cls is not None :
1316+ return self .patcher_cls (self , model , model_kwargs = model_kwargs )
13031317 # Refer to DecoderModelPatcher.
13041318 return self .orig_export_config .patch_model_for_export (model , model_kwargs = model_kwargs )
13051319
@@ -1312,6 +1326,8 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
13121326 orig_inputs = self .orig_export_config .inputs
13131327 input_ids_config = orig_inputs .pop ("input_ids" )
13141328 orig_inputs ["inputs_embeds" ] = input_ids_config
1329+ if self .input_info_upd is not None :
1330+ orig_inputs .update (self .input_info_upd )
13151331 return orig_inputs
13161332
13171333 def generate_dummy_inputs (self , framework : str = "pt" , ** kwargs ):
@@ -1383,9 +1399,22 @@ def get_vlm_text_embeddings_config(model_type, model_config, int_dtype, float_dt
13831399 return export_config
13841400
13851401
def get_vlm_text_generation_config(
    model_type,
    model_config,
    int_dtype,
    float_dtype,
    model_patcher=None,
    dummy_input_generator=None,
    inputs_update=None,
):
    """Build the inputs_embeds-driven text-generation export config for a VLM language model.

    Args:
        model_type: text backbone type used to look up the internal export config.
        model_config: full (multimodal) model config.
        int_dtype: integer dtype name for dummy input generation.
        float_dtype: float dtype name for dummy input generation.
        model_patcher: optional patcher class forwarded to LMInputEmbedsConfigHelper.
        dummy_input_generator: optional extra dummy-input generator (takes priority).
        inputs_update: optional overrides merged into the exported inputs spec.
    """
    base_config = get_vlm_internal_text_generation_config(model_type, model_config, int_dtype, float_dtype)
    wrapped = LMInputEmbedsConfigHelper(
        base_config,
        patcher_cls=model_patcher,
        dummy_input_generator=dummy_input_generator,
        inputs_update=inputs_update,
    )
    # Keep the helper's normalized config in sync with the internal config.
    wrapped._normalized_config = base_config._normalized_config
    return wrapped
@@ -1820,9 +1849,11 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
18201849 img_ids_height = self .height // 2
18211850 img_ids_width = self .width // 2
18221851 return self .random_int_tensor (
1823- [self .batch_size , img_ids_height * img_ids_width , 3 ]
1824- if is_diffusers_version ("<" , "0.31.0" )
1825- else [img_ids_height * img_ids_width , 3 ],
1852+ (
1853+ [self .batch_size , img_ids_height * img_ids_width , 3 ]
1854+ if is_diffusers_version ("<" , "0.31.0" )
1855+ else [img_ids_height * img_ids_width , 3 ]
1856+ ),
18261857 min_value = 0 ,
18271858 max_value = min (img_ids_height , img_ids_width ),
18281859 framework = framework ,
@@ -2259,3 +2290,218 @@ def patch_model_for_export(
22592290 if self ._behavior == Phi3VisionConfigBehavior .VISION_EMBEDDINGS :
22602291 return Phi3VisionImageEmbeddingsPatcher (self , model , model_kwargs )
22612292 return super ().patch_model_for_export (model , model_kwargs )
2293+
2294+
class DummyQwen2VLLMInputGenerator(DummyTextInputGenerator):
    """Text-input generator for the Qwen2-VL language model.

    Identical to DummyTextInputGenerator except that `position_ids` gains a leading
    axis of size 3 — presumably the multimodal rotary sections (t/h/w); confirm
    against the Qwen2-VL modeling code.
    """

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        base = super().generate(input_name, framework, int_dtype, float_dtype)
        if input_name != "position_ids":
            return base
        # Broadcast the 2D (batch, seq) ids across the leading axis of 3.
        return base.unsqueeze(0).expand(3, -1, -1)
2301+
2302+
class DummyQwen2VLVisionEMbedInputGenerator(DummyVisionInputGenerator):
    """Generates flattened patch inputs ("hidden_states") for the Qwen2-VL vision
    patch-embedding module.

    The output is 2D: one row per patch across the (t, h, w) grid, each row holding
    the unfolded channel x temporal x spatial patch values.
    """

    SUPPORTED_INPUT_NAMES = ("hidden_states",)

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedVisionConfig,
        batch_size: int = 1,
        num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
        width: int = 420,
        height: int = 420,
        **kwargs,
    ):
        # NOTE: intentionally does not call super().__init__; only the attributes
        # set here are read by generate().
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.num_channels = num_channels
        vision_cfg = normalized_config.config
        self.temporal_patch_size = vision_cfg.temporal_patch_size
        self.patch_size = vision_cfg.patch_size

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        # batch_size plays the role of the temporal grid dimension here.
        patches_per_frame = (self.height // self.patch_size) * (self.width // self.patch_size)
        total_patches = self.batch_size * patches_per_frame
        values_per_patch = self.num_channels * self.temporal_patch_size * self.patch_size**2
        return self.random_float_tensor([total_patches, values_per_patch], framework=framework, dtype=float_dtype)
2331+
2332+
class DummyQwen2VLVisionEmbedMergerInputGenerator(DummyVisionInputGenerator):
    """Dummy inputs for the Qwen2-VL vision transformer with merger: patch
    embeddings ("hidden_states"), a full attention mask, and precomputed rotary
    position embeddings, all sized by the (t, h, w) patch grid."""

    SUPPORTED_INPUT_NAMES = ("hidden_states", "attention_mask", "rotary_pos_emb")

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedVisionConfig,
        batch_size: int = 1,
        num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
        width: int = 420,
        height: int = 420,
        **kwargs,
    ):
        # NOTE: intentionally does not call super().__init__; only the attributes
        # set here are read by generate().
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.num_channels = num_channels
        vision_cfg = normalized_config.config
        self.temporal_patch_size = vision_cfg.temporal_patch_size
        self.patch_size = vision_cfg.patch_size
        self.embed_dim = vision_cfg.embed_dim
        self.num_heads = vision_cfg.num_heads

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        # Sequence length = temporal grid (batch_size) x spatial patch grid.
        seq_len = self.batch_size * (self.height // self.patch_size) * (self.width // self.patch_size)

        if input_name == "hidden_states":
            return self.random_float_tensor([seq_len, self.embed_dim], framework=framework, dtype=float_dtype)

        if input_name == "attention_mask":
            return self.random_mask_tensor([1, seq_len, seq_len], framework=framework, dtype=float_dtype)

        if input_name == "rotary_pos_emb":
            # Half the per-head dim: rotary embeddings pair up feature channels.
            half_head_dim = self.embed_dim // self.num_heads // 2
            return self.random_float_tensor([seq_len, half_head_dim], framework=framework, dtype=float_dtype)
2372+
2373+
class Qwen2VLConfigBehavior(str, enum.Enum):
    """Which Qwen2-VL submodel an export config instance describes."""

    # Decoder language model (the full model, fed via inputs_embeds).
    LANGUAGE = "language"
    # Vision patch-embedding module (model.visual.patch_embed).
    VISION_EMBEDDINGS = "vision_embeddings"
    # Full vision tower including the merger (model.visual).
    VISION_EMBEDDINGS_MERGER = "vision_embeddings_merger"
    # Token embedding layer (model.model.embed_tokens).
    TEXT_EMBEDDINGS = "text_embeddings"
2379+
2380+
@register_in_tasks_manager("qwen2-vl", *["image-text-to-text"], library_name="transformers")
class Qwen2VLOpenVINOConfig(OnnxConfig):
    """Export config for Qwen2-VL, split into language / vision-embeddings /
    vision-embeddings-merger / text-embeddings submodels (see Qwen2VLConfigBehavior).

    Instances describe one submodel at a time; `with_behavior` produces the config
    for another submodel and `get_model_for_behavior` extracts the matching module.
    """

    SUPPORTED_BEHAVIORS = [model_type.value for model_type in Qwen2VLConfigBehavior]
    NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyQwen2VLVisionEMbedInputGenerator,)
    MIN_TRANSFORMERS_VERSION = version.parse("4.45.0")

    def __init__(
        self,
        config: "PretrainedConfig",
        task: str = "feature-extraction",
        int_dtype: str = "int64",
        float_dtype: str = "fp32",
        behavior: Qwen2VLConfigBehavior = Qwen2VLConfigBehavior.VISION_EMBEDDINGS,
        preprocessors: Optional[List[Any]] = None,
    ):
        super().__init__(
            config=config,
            task=task,
            int_dtype=int_dtype,
            float_dtype=float_dtype,
            preprocessors=preprocessors,
        )
        self._behavior = behavior
        self._orig_config = config
        # For the vision submodels, narrow config/normalized config to the vision
        # part and select the matching dummy-input generator.
        if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS and hasattr(config, "vision_config"):
            self._config = config.vision_config
            self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
            self.DUMMY_INPUT_GENERATOR_CLASSES = (DummyQwen2VLVisionEMbedInputGenerator,)
        if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER and hasattr(config, "vision_config"):
            self._config = config.vision_config
            self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
            self.DUMMY_INPUT_GENERATOR_CLASSES = (DummyQwen2VLVisionEmbedMergerInputGenerator,)

    @staticmethod
    def get_model_for_behavior(model, behavior: Union[str, Qwen2VLConfigBehavior]):
        """Return the submodule of `model` matching `behavior` (None if unknown)."""
        if isinstance(behavior, str) and not isinstance(behavior, Qwen2VLConfigBehavior):
            behavior = Qwen2VLConfigBehavior(behavior)

        if behavior == Qwen2VLConfigBehavior.LANGUAGE:
            return model

        if behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS:
            vision_embeddings = model.visual.patch_embed
            vision_embeddings.config = model.config.vision_config
            return vision_embeddings

        if behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER:
            vision_emb_merger = model.visual
            vision_emb_merger.config = model.config.vision_config
            return vision_emb_merger

        if behavior == Qwen2VLConfigBehavior.TEXT_EMBEDDINGS:
            text_embedding = model.model.embed_tokens
            text_embedding.config = model.config
            return text_embedding

    def with_behavior(
        self,
        behavior: Union[str, Qwen2VLConfigBehavior],
    ):
        """
        Creates a config for different behaviour.
        Args:
            behavior ([`ConfigBehavior`]):
                The behavior to use for the new instance.
        """
        if isinstance(behavior, str) and not isinstance(behavior, Qwen2VLConfigBehavior):
            behavior = Qwen2VLConfigBehavior(behavior)

        if behavior == Qwen2VLConfigBehavior.TEXT_EMBEDDINGS:
            return get_vlm_text_embeddings_config("qwen2", self._orig_config, self.int_dtype, self.float_dtype)

        if behavior == Qwen2VLConfigBehavior.LANGUAGE:
            return get_vlm_text_generation_config(
                "qwen2",
                self._orig_config,
                self.int_dtype,
                self.float_dtype,
                model_patcher=Qwen2VLLanguageModelPatcher,
                dummy_input_generator=DummyQwen2VLLMInputGenerator,
                # position_ids for Qwen2-VL are 3D: (3, batch, seq).
                inputs_update={"position_ids": {1: "batch_size", 2: "sequence_length"}},
            )

        # Both vision behaviors share the same constructor call; __init__ picks the
        # right narrowed config and dummy-input generator from `behavior`.
        if behavior in (Qwen2VLConfigBehavior.VISION_EMBEDDINGS, Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER):
            return self.__class__(
                self._orig_config,
                task=self.task,
                int_dtype=self.int_dtype,
                float_dtype=self.float_dtype,
                behavior=behavior,
                preprocessors=self._preprocessors,
            )

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ):
        """Return the patcher for the current behavior (merger needs a custom one)."""
        model_kwargs = model_kwargs or {}
        if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER:
            return Qwen2VLVisionEmbMergerPatcher(self, model, model_kwargs)
        return super().patch_model_for_export(model, model_kwargs)

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        # Bug fix: this previously compared against Phi3VisionConfigBehavior.VISION_EMBEDDINGS,
        # which only matched by accident through the shared str-enum value.
        if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS:
            return {"hidden_states": {0: "patch_thw_grid", 1: "patch_temporal_channels"}}
        if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER:
            return {
                "hidden_states": {0: "sequence_length"},
                "attention_mask": {1: "sequence_length", 2: "sequence_length"},
                "rotary_pos_emb": {0: "sequence_length"},
            }
        # Other behaviors export through dedicated configs (see with_behavior);
        # return an empty spec instead of the previous implicit None.
        return {}

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        if self._behavior in [Qwen2VLConfigBehavior.VISION_EMBEDDINGS, Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER]:
            return {"last_hidden_state": {0: "seq_len"}}
        return {}
0 commit comments