diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py
index 6fa0e6348d66..6a88f47e6157 100644
--- a/src/transformers/integrations/executorch.py
+++ b/src/transformers/integrations/executorch.py
@@ -415,13 +415,16 @@ def __init__(
         super().__init__()
         self.model = model
 
+        # For multimodal models, use text_config if available
+        config = getattr(self.model.config, "text_config", self.model.config)
+
         # Verify the model is configured for HybridCache
-        if not self.model.config.use_cache:
+        if not config.use_cache:
             raise AssertionError("Model must have caching enabled")
 
         # Initialize the HybridCache
         self.cache = HybridCache(
-            config=self.model.config,
+            config=config,
             max_batch_size=max_batch_size,
             max_cache_len=max_cache_len,
             device=self.model.device,
@@ -435,20 +438,31 @@ def __init__(
 
     def forward(
         self,
-        input_ids: torch.Tensor,
-        cache_position: torch.Tensor,
+        input_ids: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """
         Forward pass of the module, which is compatible with the ExecuTorch llm runner.
 
         Args:
-            input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
-            cache_position (`torch.Tensor`): Tensor representing current input position in the cache.
+            input_ids (`torch.Tensor`, *optional*):
+                Tensor representing current input token id to the module.
+            inputs_embeds (`torch.Tensor`, *optional*):
+                Tensor representing input embeddings. Used for multimodal models.
+            cache_position (`torch.Tensor`, *optional*):
+                Tensor representing current input position in the cache.
 
         Returns:
             torch.Tensor: Logits output from the model.
         """
-        batch_size = input_ids.shape[0]
+        if input_ids is None and inputs_embeds is None:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if cache_position is None:
+            raise ValueError("cache_position is required")
+
+        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
 
         # Generate position_ids from cache_position
         position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)
@@ -456,6 +470,7 @@ def forward(
         # Forward pass with the model
         outputs = self.model(
             input_ids=input_ids,
+            inputs_embeds=inputs_embeds,
             attention_mask=None,
             position_ids=position_ids,
             past_key_values=self.cache,
@@ -853,3 +868,188 @@ def sdpa_mask_without_vmap(
     if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix:
         causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True)
     return causal_mask
+
+
+class TorchExportableModuleForImageTextLM(torch.nn.Module):
+    """
+    A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
+    specifically for image-text LMs with a cache. This module ensures that the
+    exported model is compatible with further lowering and execution in `ExecuTorch`.
+    """
+
+    def __init__(
+        self,
+        model: PreTrainedModel,
+        max_batch_size: int = 1,
+        max_cache_len: int = 4096,
+    ):
+        """
+        Initializes the exportable module for image-text models.
+
+        Args:
+            model (`PreTrainedModel`): The pretrained model to wrap.
+            max_batch_size (int): Maximum batch size for the cache.
+            max_cache_len (int): Maximum sequence length for the cache.
+
+        Raises:
+            ValueError: If the model is configured with an unsupported cache implementation.
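+
+        Example (an illustrative sketch, not part of this module's API; assumes a
+        multimodal checkpoint, e.g. a Gemma3-style model, with caching enabled in
+        its text config):
+
+        ```python
+        from transformers import AutoModelForImageTextToText
+
+        model = AutoModelForImageTextToText.from_pretrained("<multimodal-checkpoint>")
+        exportable = TorchExportableModuleForImageTextLM(model, max_batch_size=1, max_cache_len=4096)
+        exported_program = exportable.export()
+        ```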
+        """
+        super().__init__()
+
+        if (
+            not hasattr(model.config, "text_config")
+            or not hasattr(model.config.text_config, "use_cache")
+            or model.config.text_config.use_cache is False
+        ):
+            raise ValueError("The model must have caching enabled to be performant.")
+
+        if (
+            hasattr(model.config.text_config, "layer_types")
+            and getattr(model.config.text_config, "sliding_window", None) is not None
+        ):
+            self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
+        else:
+            # If `layer_types` is not specified explicitly in the config or `sliding_window` is null,
+            # there is only one type of layer, so export will use `StaticCache` by default.
+            logging.info(
+                "Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
+            )
+            self.model = TorchExportableModuleWithStaticCache(model)
+
+        # This is the same as sdpa, but mask creation does not use `vmap`, which is not exportable
+        ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
+        ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
+        self.model.model.config._attn_implementation = "sdpa_without_vmap"
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor,
+        cache_position: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Forward pass of the module, which is compatible with the ExecuTorch llm runner.
+
+        Args:
+            inputs_embeds (`torch.Tensor`): Tensor representing input embeddings.
+            cache_position (`torch.Tensor`): Tensor representing current input position in the cache.
+
+        Returns:
+            torch.Tensor: Logits output from the model.
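+
+        Example (an illustrative sketch; assumes `model` and `exportable` as constructed
+        in the `__init__` example, and `input_ids` as a `(batch, seq_len)` token tensor;
+        producing merged text+image embeddings is left to the caller):
+
+        ```python
+        inputs_embeds = model.get_input_embeddings()(input_ids)
+        cache_position = torch.arange(inputs_embeds.shape[1], dtype=torch.long)
+        logits = exportable(inputs_embeds=inputs_embeds, cache_position=cache_position)
+        ```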
+        """
+        return self.model.forward(inputs_embeds=inputs_embeds, cache_position=cache_position)
+
+    def export(
+        self,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        dynamic_shapes: Optional[dict] = None,
+        strict: Optional[bool] = None,
+    ) -> torch.export.ExportedProgram:
+        """
+        Export the wrapped module using `torch.export`.
+
+        Args:
+            inputs_embeds (`Optional[torch.Tensor]`):
+                Tensor representing input embeddings. If not provided, a default tensor will be used.
+            cache_position (`Optional[torch.Tensor]`):
+                Tensor representing current input position in the cache. If not provided, a default tensor will be used.
+            dynamic_shapes (`Optional[dict]`):
+                Dynamic shapes to use for export if specified.
+            strict (`Optional[bool]`):
+                Flag to instruct `torch.export` to use `torchdynamo`.
+        """
+        if hasattr(self.model, "base_model_prefix"):
+            base = getattr(self.model, self.model.base_model_prefix, self.model)
+            model_device = base.device
+        elif hasattr(self.model, "model"):
+            model_device = self.model.model.device
+        else:
+            model_device = "cpu"
+            logging.warning(
+                "TorchExportableModuleForImageTextLM.export: unable to infer device from the model; defaulting to CPU."
+            )
+
+        seq_length = 3
+        if hasattr(self.model.model.config, "text_config"):
+            hidden_size = self.model.model.config.text_config.hidden_size
+        else:
+            hidden_size = self.model.model.config.hidden_size
+
+        example_inputs_embeds = (
+            inputs_embeds
+            if inputs_embeds is not None
+            else torch.zeros(1, seq_length, hidden_size, dtype=torch.float32, device=model_device)
+        )
+        example_cache_position = (
+            cache_position
+            if cache_position is not None
+            else torch.arange(seq_length, dtype=torch.long, device=model_device)
+        )
+
+        if dynamic_shapes is None:
+            seq_len_dim = torch.export.Dim("seq_length_dim", max=seq_length)
+            dynamic_shapes = {
+                "inputs_embeds": {1: seq_len_dim},
+                "cache_position": {0: seq_len_dim},
+            }
+
+        exported_program = torch.export.export(
+            self.model,
+            args=(),
+            kwargs={"inputs_embeds": example_inputs_embeds, "cache_position": example_cache_position},
+            dynamic_shapes=dynamic_shapes,
+            strict=strict if strict is not None else True,
+        )
+        return exported_program
+
+
+class ImageEncoderExportableModule(torch.nn.Module):
+    """
+    A wrapper module designed to make a vision encoder-only model exportable with `torch.export`.
+    This module ensures that the exported model is compatible with ExecuTorch.
+    """
+
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(self, pixel_values):
+        """
+        Projects the last hidden state from the vision model into language model space.
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
+                The tensors corresponding to the input images.
+
+        Returns:
+            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
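+
+        Example (an illustrative sketch; `model` is assumed to be a multimodal
+        `PreTrainedModel` exposing `vision_tower` and `multi_modal_projector`
+        attributes, as Gemma3-style models do):
+
+        ```python
+        encoder = ImageEncoderExportableModule(model)
+        pixel_values = torch.randn(1, 3, 224, 224)
+        image_features = encoder(pixel_values)
+        ```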
+        """
+        vision_outputs = self.model.vision_tower(pixel_values=pixel_values).last_hidden_state
+        image_features = self.model.multi_modal_projector(vision_outputs)
+        return image_features
+
+    def export(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        dynamic_shapes: Optional[dict] = None,
+        strict: Optional[bool] = None,
+    ) -> torch.export.ExportedProgram:
+        """
+        Export the vision encoder using `torch.export`.
+
+        Args:
+            pixel_values (`Optional[torch.Tensor]`):
+                Input images tensor. If not provided, a default tensor will be used.
+            dynamic_shapes (`Optional[dict]`):
+                Dynamic shapes to use for export if specified.
+            strict (`Optional[bool]`):
+                Flag to instruct `torch.export` to use `torchdynamo`.
+        """
+        if hasattr(self.model, "vision_tower") and hasattr(self.model.vision_tower, "config"):
+            image_size = self.model.vision_tower.config.image_size
+            num_channels = getattr(self.model.vision_tower.config, "num_channels", 3)
+        else:
+            # Default values for vision models
+            image_size = 224
+            num_channels = 3
+
+        example_pixel_values = (
+            pixel_values
+            if pixel_values is not None
+            else torch.randn(1, num_channels, image_size, image_size, dtype=torch.float32)
+        )
+
+        exported_program = torch.export.export(
+            self,
+            args=(example_pixel_values,),
+            kwargs={},
+            dynamic_shapes=dynamic_shapes,
+            strict=strict if strict is not None else False,
+        )
+        return exported_program
diff --git a/tests/integrations/test_executorch_multimodal.py b/tests/integrations/test_executorch_multimodal.py
new file mode 100644
index 000000000000..b624e70e23bb
--- /dev/null
+++ b/tests/integrations/test_executorch_multimodal.py
@@ -0,0 +1,151 @@
+# coding=utf-8
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from unittest.mock import Mock, patch
+
+import torch
+
+from transformers.integrations.executorch import (
+    ImageEncoderExportableModule,
+    TorchExportableModuleForImageTextLM,
+    TorchExportableModuleWithHybridCache,
+)
+from transformers.testing_utils import require_torch
+
+
+@require_torch
+class ExecuTorchMultimodalTest(unittest.TestCase):
+    def setUp(self):
+        # Mock multimodal model configuration
+        self.mock_config = Mock()
+        self.mock_config.text_config = Mock()
+        self.mock_config.text_config.use_cache = True
+        self.mock_config.text_config.hidden_size = 768
+        self.mock_config.text_config.num_hidden_layers = 12
+        self.mock_config.vision_config = Mock()
+        self.mock_config.vision_config.image_size = 224
+
+        # Mock model
+        self.mock_model = Mock()
+        self.mock_model.config = self.mock_config
+        self.mock_model.device = torch.device("cpu")
+        self.mock_model.dtype = torch.float32
+
+    def test_hybrid_cache_inputs_embeds_support(self):
+        """Test that TorchExportableModuleWithHybridCache supports inputs_embeds"""
+        with patch("transformers.integrations.executorch.HybridCache"):
+            # Create exportable module
+            exportable = TorchExportableModuleWithHybridCache(self.mock_model)
+
+            # Test forward with inputs_embeds
+            batch_size, seq_len, hidden_size = 1, 3, 768
+            inputs_embeds = torch.randn(batch_size, seq_len, hidden_size)
+            cache_position = torch.arange(seq_len)
+
+            # Mock model output
+            mock_output = Mock()
+            mock_output.logits = torch.randn(batch_size, seq_len, 32000)  # vocab_size
+            self.mock_model.return_value = mock_output
+
+            # Call forward
+            result = exportable.forward(inputs_embeds=inputs_embeds, cache_position=cache_position)
+
+            # Verify the model was called with inputs_embeds and the logits were passed through
+            self.mock_model.assert_called_once()
+            call_kwargs = self.mock_model.call_args[1]
+            self.assertIn("inputs_embeds", call_kwargs)
+            self.assertIsNone(call_kwargs["input_ids"])
+            torch.testing.assert_close(call_kwargs["inputs_embeds"], inputs_embeds)
+            torch.testing.assert_close(result, mock_output.logits)
+
+    def test_hybrid_cache_multimodal_config(self):
+        """Test that TorchExportableModuleWithHybridCache uses text_config for multimodal models"""
+        with patch("transformers.integrations.executorch.HybridCache") as MockCache:
+            # Create exportable module
+            TorchExportableModuleWithHybridCache(self.mock_model)
+
+            # Verify HybridCache was initialized with text_config
+            MockCache.assert_called_once()
+            call_args = MockCache.call_args[1]
+            self.assertEqual(call_args["config"], self.mock_config.text_config)
+
+    def test_image_text_lm_module(self):
+        """Test TorchExportableModuleForImageTextLM initialization"""
+        with patch("transformers.integrations.executorch.TorchExportableModuleWithHybridCache") as MockWrapper:
+            with patch("transformers.integrations.executorch.ALL_MASK_ATTENTION_FUNCTIONS"):
+                with patch("transformers.integrations.executorch.ALL_ATTENTION_FUNCTIONS"):
+                    # Create image-text LM module
+                    TorchExportableModuleForImageTextLM(self.mock_model)
+
+                    # Verify it creates the appropriate wrapper
+                    MockWrapper.assert_called_once_with(self.mock_model, 1, 4096)
+
+    def test_image_encoder_module(self):
+        """Test ImageEncoderExportableModule"""
+        # Mock vision model
+        mock_vision_tower = Mock()
+        mock_vision_outputs = Mock()
+        mock_vision_outputs.last_hidden_state = torch.randn(1, 196, 768)  # 14x14 patches
+        mock_vision_tower.return_value = mock_vision_outputs
+
+        mock_projector = Mock()
+        mock_projector.return_value = torch.randn(1, 196, 768)  # projected features
+
+        mock_model = Mock()
+        mock_model.vision_tower = mock_vision_tower
+        mock_model.multi_modal_projector = mock_projector
+
+        # Create encoder module
+        encoder = ImageEncoderExportableModule(mock_model)
+
+        # Test forward pass
+        pixel_values = torch.randn(1, 3, 224, 224)
+        result = encoder.forward(pixel_values)
+
+        # Verify calls and that the projected features are returned
+        mock_vision_tower.assert_called_once_with(pixel_values=pixel_values)
+        mock_projector.assert_called_once_with(mock_vision_outputs.last_hidden_state)
+        torch.testing.assert_close(result, mock_projector.return_value)
+
+    def test_error_handling(self):
+        """Test error handling for invalid configurations"""
+        # Test missing cache configuration
+        bad_config = Mock()
+        bad_config.text_config = Mock()
+        bad_config.text_config.use_cache = False
+
+        bad_model = Mock()
+        bad_model.config = bad_config
+
+        with self.assertRaises(ValueError):
+            TorchExportableModuleForImageTextLM(bad_model)
+
+    def test_forward_validation(self):
+        """Test input validation in forward method"""
+        with patch("transformers.integrations.executorch.HybridCache"):
+            exportable = TorchExportableModuleWithHybridCache(self.mock_model)
+
+            # Test missing both input_ids and inputs_embeds
+            with self.assertRaises(ValueError):
+                exportable.forward(cache_position=torch.tensor([0]))
+
+            # Test missing cache_position
+            with self.assertRaises(ValueError):
+                exportable.forward(input_ids=torch.tensor([[1]]))
+
+
+if __name__ == "__main__":
+    unittest.main()