add multimodal executorch support #39832
```diff
@@ -415,13 +415,16 @@ def __init__(
         super().__init__()
         self.model = model
 
+        # For multimodal models, use text_config if available
+        config = getattr(self.model.config, 'text_config', self.model.config)
+
         # Verify the model is configured for HybridCache
-        if not self.model.config.use_cache:
+        if not config.use_cache:
             raise AssertionError("Model must have caching enabled")
 
         # Initialize the HybridCache
         self.cache = HybridCache(
-            config=self.model.config,
+            config=config,
             max_batch_size=max_batch_size,
             max_cache_len=max_cache_len,
             device=self.model.device,
```
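The `getattr` fallback above is what lets the existing hybrid-cache wrapper serve multimodal models: a text-only model resolves to its top-level config, while a multimodal one resolves to its nested decoder config. A minimal sketch of that resolution; both checkpoint names are illustrative assumptions, not taken from this PR:

```python
# Sketch of the config resolution the wrapper now performs in __init__.
# Both checkpoint names are illustrative assumptions.
from transformers import AutoConfig

text_cfg = AutoConfig.from_pretrained("google/gemma-2-2b")   # plain LM: no text_config
mm_cfg = AutoConfig.from_pretrained("google/gemma-3-4b-it")  # multimodal: nested text_config

for cfg in (text_cfg, mm_cfg):
    resolved = getattr(cfg, "text_config", cfg)  # same expression as in the diff
    print(type(resolved).__name__, "use_cache =", resolved.use_cache)
```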
```diff
@@ -435,27 +438,39 @@ def __init__(
 
     def forward(
         self,
-        input_ids: torch.Tensor,
-        cache_position: torch.Tensor,
+        input_ids: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """
         Forward pass of the module, which is compatible with the ExecuTorch llm runner.
 
         Args:
-            input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
-            cache_position (`torch.Tensor`): Tensor representing current input position in the cache.
+            input_ids (`torch.Tensor`, *optional*):
+                Tensor representing current input token id to the module.
+            inputs_embeds (`torch.Tensor`, *optional*):
+                Tensor representing input embeddings. Used for multimodal models.
+            cache_position (`torch.Tensor`, *optional*):
+                Tensor representing current input position in the cache.
 
         Returns:
             torch.Tensor: Logits output from the model.
         """
-        batch_size = input_ids.shape[0]
+        if input_ids is None and inputs_embeds is None:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if cache_position is None:
+            raise ValueError("cache_position is required")
+
+        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
 
         # Generate position_ids from cache_position
         position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)
 
         # Forward pass with the model
         outputs = self.model(
             input_ids=input_ids,
+            inputs_embeds=inputs_embeds,
             attention_mask=None,
             position_ids=position_ids,
             past_key_values=self.cache,
```
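With the widened signature, one wrapper serves both entry points: token ids for text decode, precomputed embeddings for multimodal prefill. A hedged usage sketch; the import path and checkpoint are assumptions, not a tested recipe from this PR:

```python
import torch
from transformers import AutoModelForCausalLM
from transformers.integrations.executorch import TorchExportableModuleWithHybridCache

# Illustrative hybrid-cache checkpoint (alternating sliding/global attention layers).
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
wrapper = TorchExportableModuleWithHybridCache(model, max_batch_size=1, max_cache_len=128)

# Multimodal-style prefill: three positions filled from precomputed embeddings.
hidden_size = getattr(model.config, "text_config", model.config).hidden_size
embeds = torch.zeros(1, 3, hidden_size)
logits = wrapper(inputs_embeds=embeds, cache_position=torch.arange(3))

# Text-style decode: one step driven by a token id at the next cache position.
logits = wrapper(input_ids=torch.tensor([[42]]), cache_position=torch.tensor([3]))
```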
```diff
@@ -853,3 +868,188 @@ def sdpa_mask_without_vmap(
     if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix:
         causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True)
     return causal_mask
+
+
+class TorchExportableModuleForImageTextLM(torch.nn.Module):
+    """
+    A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
+    specifically for image-text LM with cache. This module ensures that the
+    exported model is compatible with further lowering and execution in `ExecuTorch`.
+    """
+
+    def __init__(
+        self,
+        model: PreTrainedModel,
+        max_batch_size: int = 1,
+        max_cache_len: int = 4096,
+    ):
+        """
+        Initializes the exportable module for image-text models.
+
+        Args:
+            model (`PreTrainedModel`): The pretrained model to wrap.
+            max_batch_size (int): Maximum batch size for the cache.
+            max_cache_len (int): Maximum sequence length for the cache.
+
+        Raises:
+            ValueError: If the model is configured with an unsupported cache implementation.
+        """
+        super().__init__()
+
+        if not hasattr(model.config, "text_config") or not hasattr(model.config.text_config, "use_cache") or model.config.text_config.use_cache is False:
+            raise ValueError("The model must have caching enabled to be performant.")
+
+        if hasattr(model.config.text_config, "layer_types") and getattr(model.config.text_config, "sliding_window", None) is not None:
+            self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
+        else:
+            # If `layer_types` is not specified explicitly in the config or `sliding_window` is null,
+            # there is only 1 type of layers, so export will use `StaticCache` by default.
+            logging.info(
+                "Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
+            )
+            self.model = TorchExportableModuleWithStaticCache(model)
+
+        # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
+        ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
+        ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
+        self.model.model.config._attn_implementation = "sdpa_without_vmap"
```
> **Review comment** (on the `_attn_implementation` assignment): Let's use public API -
```diff
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor,
+        cache_position: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Forward pass of the module, which is compatible with the ExecuTorch llm runner.
+
+        Args:
+            inputs_embeds (`torch.Tensor`): Tensor representing input embeddings.
+            cache_position (`torch.Tensor`): Tensor representing current input position in the cache.
+
+        Returns:
+            torch.Tensor: Logits output from the model.
+        """
+        return self.model.forward(inputs_embeds=inputs_embeds, cache_position=cache_position)
+
+    def export(
+        self,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        dynamic_shapes: Optional[dict] = None,
+        strict: Optional[bool] = None,
+    ) -> torch.export.ExportedProgram:
+        """
+        Export the wrapped module using `torch.export`.
+
+        Args:
+            inputs_embeds (`Optional[torch.Tensor]`):
+                Tensor representing input embeddings. If not provided, a default tensor will be used.
+            cache_position (`Optional[torch.Tensor]`):
+                Tensor representing current input position in the cache. If not provided, a default tensor will be used.
+            dynamic_shapes (`Optional[dict]`):
+                Dynamic shapes to use for export if specified.
+            strict (`Optional[bool]`):
+                Flag to instruct `torch.export` to use `torchdynamo`.
+        """
+        if hasattr(self.model, "base_model_prefix"):
+            base = getattr(self.model, self.model.base_model_prefix, self.model)
+            model_device = base.device
+        elif hasattr(self.model, "model"):
+            model_device = self.model.model.device
+        else:
+            model_device = "cpu"
+            logging.warning(
+                "TorchExportableModuleForImageTextLM.export Can't infer device from the model. Set to CPU by default."
+            )
+
+        seq_length = 3
+        hidden_size = self.model.model.config.text_config.hidden_size if hasattr(self.model.model.config, 'text_config') else self.model.model.config.hidden_size
+
+        example_inputs_embeds = (
+            inputs_embeds if inputs_embeds is not None
+            else torch.zeros(1, seq_length, hidden_size, dtype=torch.float32, device=model_device)
+        )
+        example_cache_position = (
+            cache_position if cache_position is not None
+            else torch.arange(seq_length, dtype=torch.long, device=model_device)
+        )
+
+        if dynamic_shapes is None:
+            seq_len_dim = torch.export.Dim("seq_length_dim", max=seq_length)
+            dynamic_shapes = {
+                "inputs_embeds": {1: seq_len_dim},
+                "cache_position": {0: seq_len_dim},
+            }
+
+        exported_program = torch.export.export(
+            self.model,
+            args=(),
+            kwargs={"inputs_embeds": example_inputs_embeds, "cache_position": example_cache_position},
+            dynamic_shapes=dynamic_shapes,
+            strict=strict if strict is not None else True,
+        )
+        return exported_program
```
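Putting the decoder path together, a sketch of the intended flow for the new class. The checkpoint is an illustrative assumption, and the import assumes this PR's classes land in `transformers.integrations.executorch`:

```python
from transformers import AutoModelForImageTextToText
from transformers.integrations.executorch import TorchExportableModuleForImageTextLM

# Assumed multimodal checkpoint whose text_config has layer_types + sliding_window,
# so the wrapper picks the HybridCache branch.
model = AutoModelForImageTextToText.from_pretrained("google/gemma-3-4b-it")
exportable = TorchExportableModuleForImageTextLM(model, max_batch_size=1, max_cache_len=1024)

# Default example inputs: zero embeddings of shape (1, 3, hidden_size) and
# cache_position = arange(3), with the sequence dimension marked dynamic.
ep = exportable.export()
```

Note that the default `dynamic_shapes` caps the dynamic sequence dimension at the example length (`max=seq_length`, i.e. 3), so longer prefills require passing explicit `inputs_embeds` or `dynamic_shapes`.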
```diff
+
+
+class ImageEncoderExportableModule(torch.nn.Module):
+    """
+    A wrapper module designed to make a vision encoder-only model exportable with `torch.export`.
+    This module ensures that the exported model is compatible with ExecuTorch.
+    """
+
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(self, pixel_values):
```
> **Review comment** (on `forward(self, pixel_values)`): most models currently require extra inputs such as
| """ | ||
| Projects the last hidden state from the vision model into language model space. | ||
|
|
||
| Args: | ||
| pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`): | ||
| The tensors corresponding to the input images. | ||
| Returns: | ||
| image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`. | ||
| """ | ||
| vision_outputs = self.model.vision_tower(pixel_values=pixel_values).last_hidden_state | ||
| image_features = self.model.multi_modal_projector(vision_outputs) | ||
| return image_features | ||
|
Comment on lines
+1014
to
+1016
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ig |
||
|
|
||
| def export( | ||
| self, | ||
| pixel_values: Optional[torch.Tensor] = None, | ||
| dynamic_shapes: Optional[dict] = None, | ||
| strict: Optional[bool] = None, | ||
| ) -> torch.export.ExportedProgram: | ||
| """ | ||
| Export the vision encoder using `torch.export`. | ||
|
|
||
| Args: | ||
| pixel_values (`Optional[torch.Tensor]`): | ||
| Input images tensor. If not provided, a default tensor will be used. | ||
| dynamic_shapes (`Optional[dict]`): | ||
| Dynamic shapes to use for export if specified. | ||
| strict(`Optional[bool]`): | ||
| Flag to instruct `torch.export` to use `torchdynamo`. | ||
| """ | ||
| if hasattr(self.model, "vision_tower") and hasattr(self.model.vision_tower, "config"): | ||
| image_size = self.model.vision_tower.config.image_size | ||
| num_channels = getattr(self.model.vision_tower.config, "num_channels", 3) | ||
| else: | ||
| # Default values for vision models | ||
| image_size = 224 | ||
| num_channels = 3 | ||
|
|
||
| example_pixel_values = ( | ||
| pixel_values if pixel_values is not None | ||
| else torch.randn(1, num_channels, image_size, image_size, dtype=torch.float32) | ||
| ) | ||
|
|
||
| exported_program = torch.export.export( | ||
| self, | ||
| args=(example_pixel_values,), | ||
| kwargs={}, | ||
| dynamic_shapes=dynamic_shapes, | ||
| strict=strict if strict is not None else False, | ||
| ) | ||
| return exported_program | ||
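The vision side pairs with the decoder sketch above; again a hedged sketch, reusing the same assumed `model`:

```python
import torch
from transformers.integrations.executorch import ImageEncoderExportableModule

encoder = ImageEncoderExportableModule(model)  # `model` as loaded in the previous sketch

# Exports with the default example input (1, num_channels, image_size, image_size);
# shapes are static since no dynamic_shapes are passed.
ep = encoder.export()

# Sanity-check the exported program on a dummy image batch of the same size.
size = model.vision_tower.config.image_size
features = ep.module()(torch.randn(1, 3, size, size))
print(features.shape)  # expected: (num_images, image_length, embed_dim)
```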
> **Review comment:** I feel like this is the same as `TorchExportableModuleForDecoderOnlyLM`, with the only difference being that the input model is multimodal. We could re-use `TorchExportableModuleForDecoderOnlyLM` and ask users to export the language backbone explicitly, like `TorchExportableModuleForDecoderOnlyLM(model.language_model)`.
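If that suggestion were adopted, usage might collapse to the sketch below. Both assumptions are unverified here: that the multimodal model exposes its backbone as `model.language_model`, and that `TorchExportableModuleForDecoderOnlyLM` accepts it directly:

```python
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

# Reviewer's proposed reuse: wrap and export only the language backbone.
exportable = TorchExportableModuleForDecoderOnlyLM(model.language_model)
ep = exportable.export()
```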