add multimodal executorch support #39832
Conversation
[For maintainers] Suggested jobs to run (before merge): run-slow: moshi
14ae06d to 70e366e (compare)
This commit enhances the ExecuTorch integration to support multimodal models like Gemma-3, LLaVA, and other vision-language models. Key changes:
- Enhanced TorchExportableModuleWithHybridCache to support the inputs_embeds parameter and multimodal configs
- Added TorchExportableModuleForImageTextLM for image-text language models
- Added ImageEncoderExportableModule for vision encoders
- Added a test for multimodal functionality

This enables ExecuTorch export for vision-language models while maintaining backward compatibility with text-only models.
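A rough usage sketch of the intended flow (the wrapper names come from this PR; the import path, checkpoint, input shape, and export call are illustrative assumptions, not the PR's exact API):

```python
import torch
from transformers import AutoModelForImageTextToText

# Assumed import location, matching where the existing ExecuTorch wrappers live.
from transformers.integrations.executorch import (
    ImageEncoderExportableModule,
    TorchExportableModuleForImageTextLM,
)

# Illustrative checkpoint; any Gemma-3 / LLaVA-style vision-language model is the target.
model = AutoModelForImageTextToText.from_pretrained("google/gemma-3-4b-it")
model.eval()

# The vision encoder and the language backbone are exported as two separate programs;
# image features are merged into inputs_embeds at runtime on the ExecuTorch side.
vision_encoder = ImageEncoderExportableModule(model)
pixel_values = torch.rand(1, 3, 896, 896)  # shape depends on the model's vision config
exported_vision = torch.export.export(vision_encoder, (pixel_values,))

# The text side wraps the language model and accepts inputs_embeds instead of input_ids.
text_decoder = TorchExportableModuleForImageTextLM(model)
```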
162df79 to ff1ac47 (compare)
Hey, thanks a lot for the PR! I agree that we need to export the LM and vision backbones separately and handle input merging manually. I left a few comments; imo we should make sure the different types of multimodal architectures are exportable (i.e. expected inputs, config attr names).
```python
if not hasattr(model.config, "text_config") or not hasattr(model.config.text_config, "use_cache") or model.config.text_config.use_cache is False:
    raise ValueError("The model must have caching enabled to be performant.")
```
model.get_text_config() is more reliable because the text config attribute is not always called text_config. And since it's accessed a lot below, we can just save it as self.text_config = model.get_text_config().
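A minimal sketch of that suggestion (the class name matches the PR; the __init__ body is trimmed to the relevant part and is otherwise an assumption):

```python
import torch


class TorchExportableModuleForImageTextLM(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        # get_text_config() resolves the language-model config regardless of
        # whether this architecture names the attribute `text_config`.
        self.text_config = model.get_text_config()
        if getattr(self.text_config, "use_cache", False) is False:
            raise ValueError("The model must have caching enabled to be performant.")
```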
```python
# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
self.model.model.config._attn_implementation = "sdpa_without_vmap"
```
Let's use the public API: model.set_attn_implementation("sdpa_without_vmap").
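Concretely, only the last line of the quoted snippet would change (a sketch; the registration lines stay exactly as in the PR):

```python
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
# Public setter instead of writing the private `_attn_implementation` config attribute.
self.model.set_attn_implementation("sdpa_without_vmap")
```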
```python
if hasattr(self.model, "base_model_prefix"):
    base = getattr(self.model, self.model.base_model_prefix, self.model)
    model_device = base.device
elif hasattr(self.model, "model"):
    model_device = self.model.model.device
else:
    model_device = "cpu"
    logging.warning(
        "TorchExportableModuleForImageTextLM.export Can't infer device from the model. Set to CPU by default."
    )
```
hmm, I think model.device would be fine. The model here is the language backbone
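i.e. the branching above could collapse to something like this (a sketch, assuming the wrapper holds the language backbone directly):

```python
# The wrapped module is already the language backbone, so its own device is the one we want.
model_device = self.model.device
```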
```python
        super().__init__()
        self.model = model

    def forward(self, pixel_values):
```
Most models currently require extra inputs such as num_patches, image_attn_mask, etc.
```python
        vision_outputs = self.model.vision_tower(pixel_values=pixel_values).last_hidden_state
        image_features = self.model.multi_modal_projector(vision_outputs)
        return image_features
```
I guess self.model is the multimodal model here. We should use model.get_image_features(), which handles the pipeline correctly for the given model, because some models might need extra ops on top of this.
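A minimal sketch of the suggested shape, assuming self.model is the full multimodal model (the docstring and class body here are illustrative, not the PR's exact code):

```python
import torch


class ImageEncoderExportableModule(torch.nn.Module):
    """Wraps the vision pipeline of a multimodal model so it can be exported."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, pixel_values):
        # get_image_features() runs the vision tower, the projector, and any
        # model-specific extra ops, instead of hard-coding that pipeline here.
        return self.model.get_image_features(pixel_values=pixel_values)
```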
```python
        return causal_mask


class TorchExportableModuleForImageTextLM(torch.nn.Module):
```
I feel like this is the same as TorchExportableModuleForDecoderOnlyLM, with the only difference being that the input model is multimodal. We could re-use TorchExportableModuleForDecoderOnlyLM and ask users to export the language backbone explicitly, like TorchExportableModuleForDecoderOnlyLM(model.language_model).
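Something along these lines (a sketch; TorchExportableModuleForDecoderOnlyLM already exists in transformers.integrations.executorch, while the checkpoint and constructor defaults below are assumptions):

```python
from transformers import AutoModelForImageTextToText
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

model = AutoModelForImageTextToText.from_pretrained("google/gemma-3-4b-it")  # illustrative checkpoint

# Export only the language backbone; the vision encoder is exported separately
# and its features are merged into inputs_embeds outside the exported program.
exportable_lm = TorchExportableModuleForDecoderOnlyLM(model.language_model)
```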
Hey @zucchini-nlp, thanks a lot for the thoughtful reviews. @jackzhxng will take this over the finish line in #39836. I'm going to close this PR for the time being, but I hope @jackzhxng can incorporate some of your suggestions and recommendations in that PR.
New Class: TorchExportableModuleForImageTextLM
Dedicated wrapper for image-text language models:
New Class: ImageEncoderExportableModule
Wrapper for vision encoder components: