[OpenVINO] Add support for GLM-4.1V-9B-Thinking #1387
base: main
Changes from 6 commits
@@ -96,6 +96,9 @@
```python
    FluxTransfromerModelPatcher,
    Gemma2ModelPatcher,
    Gemma3LMModelPatcher,
    Glm4vVisionEmbMergerPatcher,
    Glm4vVisionEmbeddingsPatcher,
    Glm4vLanguageModelPatcher,
    GptBigCodeModelPatcher,
    GptJModelPatcher,
    GptNeoModelPatcher,
```
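The three `Glm4v*Patcher` classes imported here are defined in the PR's model patcher changes, which are not part of this excerpt. As a rough orientation only (the class name and forward body below are illustrative, not the PR's code), optimum-style patchers are context managers that temporarily swap a module's `forward` for an export-friendly one while the model is traced:

```python
# Illustrative sketch of the usual optimum ModelPatcher pattern; this is NOT the PR's
# Glm4vVisionEmbeddingsPatcher, whose real implementation lives in model_patcher.py.
import types

from optimum.exporters.onnx.model_patcher import ModelPatcher


def _export_friendly_forward(self, hidden_states):
    # hypothetical simplified forward used only while tracing for export
    return self.patch_embed(hidden_states)


class IllustrativeVisionEmbeddingsPatcher(ModelPatcher):
    def __enter__(self):
        super().__enter__()
        # swap in the export-friendly forward for the duration of the export
        self._model.forward = types.MethodType(_export_friendly_forward, self._model)

    def __exit__(self, exc_type, exc_value, traceback):
        # the base class restores the original forward captured at construction time
        super().__exit__(exc_type, exc_value, traceback)
```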
@@ -154,6 +157,10 @@
```python
def init_model_configs():
    if "open_clip" not in TasksManager._LIBRARY_TO_SUPPORTED_MODEL_TYPES:
        TasksManager._LIBRARY_TO_SUPPORTED_MODEL_TYPES["open_clip"] = {}
    TasksManager._CUSTOM_CLASSES[("pt", "glm4v", "image-text-to-text")] = (
        "transformers",
        "Glm4vForConditionalGeneration",
    )
    TasksManager._CUSTOM_CLASSES[("pt", "llava", "image-text-to-text")] = (
        "transformers",
        "LlavaForConditionalGeneration",
```
@@ -4490,3 +4497,210 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
```python
        )

        return dummy_inputs


class DummyGlm4vVisionEmbedInputGenerator(DummyVisionInputGenerator):
    SUPPORTED_INPUT_NAMES = (
        "hidden_states",
        "seqlens",
        "grid_thw",
        "attention_mask",
        "image_type_ids",
        "rotary_pos_emb",
    )

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedVisionConfig,
        batch_size: int = 1,
        num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
        width: int = 420,
        height: int = 420,
        **kwargs,
    ):
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.num_channels = num_channels
        self.temporal_patch_size = normalized_config.config.temporal_patch_size
        self.patch_size = normalized_config.config.patch_size
        if normalized_config.use_embed_dim:
            self.embed_dim = (
                normalized_config.config.embed_dim
                if hasattr(normalized_config.config, "embed_dim")
                else normalized_config.hidden_size
            )
        else:
            self.embed_dim = self.num_channels * self.temporal_patch_size * self.patch_size * self.patch_size
        self.num_heads = normalized_config.config.num_heads
        self.spatial_merge_size = None
        if hasattr(normalized_config.config, "spatial_merge_size"):
            self.spatial_merge_size = normalized_config.config.spatial_merge_size

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        grid_h, grid_w = self.height // self.patch_size, self.width // self.patch_size
        grid_t = self.batch_size
        import torch

        if input_name == "hidden_states":
            return self.random_float_tensor(
                [grid_t * grid_h * grid_w, self.embed_dim], framework=framework, dtype=float_dtype
            )

        if input_name == "seqlens":
            return torch.tensor([grid_t * grid_h * grid_w], dtype=torch.int64)
```
Reviewer: question: do we need to generate `seqlens` here, or can we infer it directly in the patch instead?
Author: @echarlaix Sorry, I don't understand what "infer it directly in the patch instead" means. Could you show me a link to example code? Thanks.
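As a hedged sketch of what "infer it in the patch" could look like (this is not code from this PR, and the function and attribute names are made up): every value the dummy generator produces for `seqlens` is fully determined by `grid_thw`, so a patcher could derive it inside a wrapped forward instead of exposing it as a separate model input.

```python
# Sketch only: derive seqlens from grid_thw inside a patched forward so the exported merger
# would not need a separate "seqlens" input. Names are hypothetical; the PR's real patchers
# live in model_patcher.py.
import types


def _forward_with_inferred_seqlens(self, hidden_states, grid_thw, **kwargs):
    seqlens = grid_thw.prod(dim=-1)  # t * h * w patches per image, matching the dummy generator
    return self._orig_forward(hidden_states=hidden_states, seqlens=seqlens, grid_thw=grid_thw, **kwargs)


def patch_merger_forward(merger):
    merger._orig_forward = merger.forward
    merger.forward = types.MethodType(_forward_with_inferred_seqlens, merger)
    return merger
```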
```python

        if input_name in ["attention_mask", "window_attention_mask"]:
            return self.random_mask_tensor(
                [1, grid_t * grid_h * grid_w, grid_t * grid_h * grid_w], framework=framework, dtype=float_dtype
            )

        if input_name == "rotary_pos_emb":
            dim = self.embed_dim // self.num_heads // 2
            return self.random_float_tensor([grid_h * grid_t * grid_w, dim], framework=framework, dtype=float_dtype)

        if input_name == "image_type_ids":
            return self.random_int_tensor(
                [grid_t * grid_h * grid_w, 2], max_value=grid_h, framework=framework, dtype=int_dtype
            )

        if input_name == "grid_thw":
            return torch.tensor([[grid_t, grid_h, grid_w]], dtype=torch.int64)
```

Reviewer (on `image_type_ids`): looks like …
Reviewer (on `grid_thw`): same question, shouldn't it be inferred in the patch?
Author: it comes directly from the forward function.
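To make the dummy shapes concrete, here is a hypothetical way to exercise the generator with a stand-in config; the vision values used below (`patch_size=14`, `temporal_patch_size=2`, `num_heads=16`, `hidden_size=1536`) are assumptions for illustration, not values read from the real GLM-4.1V config.

```python
# Hypothetical smoke test for the generator above; config values are assumed, not GLM-4.1V's real ones.
from types import SimpleNamespace

vision_cfg = SimpleNamespace(patch_size=14, temporal_patch_size=2, num_heads=16, spatial_merge_size=2)
normalized_cfg = SimpleNamespace(config=vision_cfg, use_embed_dim=False, hidden_size=1536)

gen = DummyGlm4vVisionEmbedInputGenerator("image-text-to-text", normalized_cfg)
print(gen.generate("hidden_states").shape)  # torch.Size([900, 1176]): (420 // 14) ** 2 patches, 3 * 2 * 14 * 14 features
print(gen.generate("grid_thw"))             # tensor([[ 1, 30, 30]])
print(gen.generate("seqlens"))              # tensor([900])
```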
```python
@register_in_tasks_manager("glm4v", *["image-text-to-text", "video-text-to-text"], library_name="transformers")
class Glm4vOpenVINOConfig(BaseVLMOpenVINOConfig):
    SUPPORTED_BEHAVIORS = [model_type.value for model_type in Qwen2VLConfigBehavior]
    NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyGlm4vVisionEmbedInputGenerator,)
    MIN_TRANSFORMERS_VERSION = version.parse("4.54.0")

    def __init__(
        self,
        config: "PretrainedConfig",
        task: str = "feature-extraction",
        int_dtype: str = "int64",
        float_dtype: str = "fp32",
        behavior: Qwen2VLConfigBehavior = Qwen2VLConfigBehavior.VISION_EMBEDDINGS,
        preprocessors: Optional[List[Any]] = None,
        **kwargs,
    ):
        super().__init__(
            config=config,
            task=task,
            int_dtype=int_dtype,
            float_dtype=float_dtype,
            preprocessors=preprocessors,
        )
        self._behavior = behavior
        self._orig_config = config
        if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS and hasattr(config, "vision_config"):
            self._config = config.vision_config
            self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
            self._normalized_config.use_embed_dim = False
        if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER and hasattr(config, "vision_config"):
            self._config = config.vision_config
            self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
            self._normalized_config.use_embed_dim = True

    @staticmethod
    def get_model_for_behavior(model, behavior: Union[str, Qwen2VLConfigBehavior]):
        if isinstance(behavior, str) and not isinstance(behavior, Qwen2VLConfigBehavior):
            behavior = Qwen2VLConfigBehavior(behavior)

        if behavior == Qwen2VLConfigBehavior.LANGUAGE:
            return model

        if behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS:
            vision_embeddings = model.visual
            vision_embeddings.config = model.config.vision_config
            return vision_embeddings

        if behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER:
            vision_emb_merger = model.visual
            vision_emb_merger.config = model.config.vision_config
            return vision_emb_merger

        if behavior == Qwen2VLConfigBehavior.TEXT_EMBEDDINGS:
            text_embedding = (
                model.model.embed_tokens if hasattr(model.model, "embed_tokens") else model.language_model.embed_tokens
            )
            text_embedding.config = model.config
            return text_embedding

    def with_behavior(
        self,
        behavior: Union[str, Qwen2VLConfigBehavior],
    ):
        """
        Creates a config for different behaviour.
        Args:
            behavior ([`ConfigBehavior`]):
                The behavior to use for the new instance.
        """
        if isinstance(behavior, str) and not isinstance(behavior, Qwen2VLConfigBehavior):
            behavior = Qwen2VLConfigBehavior(behavior)

        if behavior == Qwen2VLConfigBehavior.TEXT_EMBEDDINGS:
            return get_vlm_text_embeddings_config("qwen2", self._orig_config, self.int_dtype, self.float_dtype)

        if behavior == Qwen2VLConfigBehavior.LANGUAGE:
            return get_vlm_text_generation_config(
                "qwen2",
                self._orig_config,
                self.int_dtype,
                self.float_dtype,
                model_patcher=Glm4vLanguageModelPatcher,
                dummy_input_generator=DummyQwen2VLLMInputGenerator,
                inputs_update={"position_ids": {1: "batch_size", 2: "sequence_length"}},
            )

        if behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS:
            return self.__class__(
                self._orig_config,
                task=self.task,
                int_dtype=self.int_dtype,
                float_dtype=self.float_dtype,
                behavior=behavior,
                preprocessors=self._preprocessors,
            )
        if behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER:
            return self.__class__(
                self._orig_config,
                task=self.task,
                int_dtype=self.int_dtype,
                float_dtype=self.float_dtype,
                behavior=behavior,
                preprocessors=self._preprocessors,
            )

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ):
        model_kwargs = model_kwargs or {}
        if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER:
            return Glm4vVisionEmbMergerPatcher(self, model, model_kwargs)
        if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS:
            return Glm4vVisionEmbeddingsPatcher(self, model, model_kwargs=model_kwargs)
        return super().patch_model_for_export(model, model_kwargs)

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS:
            return {"hidden_states": {0: "patch_thw_grid", 1: "patch_temporal_channels"}}
        if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER:
            return {
                "hidden_states": {0: "sequence_length"},
                "seqlens": {0: "sequence_length"},
                "grid_thw": {0: "sequence_length"},
                "attention_mask": {1: "sequence_length", 2: "sequence_length"},
                "image_type_ids": {0: "sequence_length"},
                "rotary_pos_emb": {0: "sequence_length"},
            }

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        if self._behavior in [Qwen2VLConfigBehavior.VISION_EMBEDDINGS, Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER]:
            return {"last_hidden_state": {0: "seq_len"}}
        return {}
```
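Once the config is registered, a quick end-to-end check could look like the following. This is a hypothetical usage sketch: the Hub id is an assumption, and it presumes the corresponding GLM-4.1V support on the `OVModelForVisualCausalLM` modeling side.

```python
# Hypothetical usage once glm4v is wired up; the checkpoint id is an assumption.
from optimum.intel import OVModelForVisualCausalLM

model = OVModelForVisualCausalLM.from_pretrained(
    "THUDM/GLM-4.1V-9B-Thinking",  # assumed Hub id
    export=True,                   # convert the text-embeddings / vision / merger / language parts to OpenVINO
)
```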
Reviewer: here, why not use `AutoModelForImageTextToText` directly to load all the image-text-to-text task models? https://github.com/huggingface/transformers/blob/5dba4bc7b2c1ef517ed44bba76bb70b59001c737/src/transformers/models/auto/modeling_auto.py#L941
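For reference, the auto class the reviewer points to resolves the `glm4v` model type to its conditional-generation class through the image-text-to-text mapping, which is the basis of the suggestion that the explicit `_CUSTOM_CLASSES` entry may be redundant. A minimal sketch of that alternative (checkpoint id is a placeholder):

```python
# What the reviewer suggests, roughly: let the auto class pick the right glm4v class itself.
from transformers import AutoModelForImageTextToText

model = AutoModelForImageTextToText.from_pretrained("path/to/glm-4.1v-checkpoint")
```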