214 changes: 207 additions & 7 deletions src/transformers/integrations/executorch.py
@@ -415,13 +415,16 @@ def __init__(
super().__init__()
self.model = model

# For multimodal models, use text_config if available
config = getattr(self.model.config, 'text_config', self.model.config)

# Verify the model is configured for HybridCache
if not self.model.config.use_cache:
if not config.use_cache:
raise AssertionError("Model must have caching enabled")

# Initialize the HybridCache
self.cache = HybridCache(
config=self.model.config,
config=config,
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
device=self.model.device,
@@ -435,27 +435,39 @@ def __init__(

def forward(
self,
input_ids: torch.Tensor,
cache_position: torch.Tensor,
input_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
cache_position: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Forward pass of the module, which is compatible with the ExecuTorch llm runner.

Args:
input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
cache_position (`torch.Tensor`): Tensor representing current input position in the cache.
input_ids (`torch.Tensor`, *optional*):
Tensor representing current input token id to the module.
inputs_embeds (`torch.Tensor`, *optional*):
Tensor representing input embeddings. Used for multimodal models.
cache_position (`torch.Tensor`, *optional*):
Tensor representing current input position in the cache.

Returns:
torch.Tensor: Logits output from the model.
"""
batch_size = input_ids.shape[0]
if input_ids is None and inputs_embeds is None:
raise ValueError("You have to specify either input_ids or inputs_embeds")

if cache_position is None:
raise ValueError("cache_position is required")

batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]

# Generate position_ids from cache_position
position_ids = cache_position.unsqueeze(0).expand(batch_size, -1)

# Forward pass with the model
outputs = self.model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
attention_mask=None,
position_ids=position_ids,
past_key_values=self.cache,
@@ -853,3 +868,188 @@ def sdpa_mask_without_vmap(
if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix:
causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True)
return causal_mask


class TorchExportableModuleForImageTextLM(torch.nn.Module):
Reviewer comment:
I feel like this is the same as TorchExportableModuleForDecoderOnlyLM, with the only difference being that the input model is multimodal. We could re-use TorchExportableModuleForDecoderOnlyLM and ask users to export the language backbone explicitly, like TorchExportableModuleForDecoderOnlyLM(model.language_model).
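A minimal sketch of that reuse (the checkpoint name and the `language_model` attribute are illustrative; confirm the attribute name for your model):

from transformers import AutoModelForImageTextToText
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

# Wrap only the language backbone of the multimodal model, as suggested,
# instead of introducing a new wrapper class for image-text models.
model = AutoModelForImageTextToText.from_pretrained("google/gemma-3-4b-it")
wrapper = TorchExportableModuleForDecoderOnlyLM(
    model.language_model, max_batch_size=1, max_cache_len=4096
)
exported = wrapper.export()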

"""
A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
specifically for image-text LM with cache. This module ensures that the
exported model is compatible with further lowering and execution in `ExecuTorch`.
"""

def __init__(
self,
model: PreTrainedModel,
max_batch_size: int = 1,
max_cache_len: int = 4096,
):
"""
Initializes the exportable module for image-text models.

Args:
model (`PreTrainedModel`): The pretrained model to wrap.
max_batch_size (int): Maximum batch size for the cache.
max_cache_len (int): Maximum sequence length for the cache.

Raises:
ValueError: If the model is configured with an unsupported cache implementation.
"""
super().__init__()

if (
    not hasattr(model.config, "text_config")
    or not hasattr(model.config.text_config, "use_cache")
    or model.config.text_config.use_cache is False
):
    raise ValueError("The model must have caching enabled to be performant.")
Reviewer comment on lines +899 to +900:
model.get_text_config() is more reliable because it is not always called text_config. And since it's accessed a lot below, we can just save it in self.text_config = model.get_text_config()
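A sketch of that refactor (assuming `get_text_config()` is reached via the config object, which is where it lives in recent transformers):

# get_text_config() resolves the nested text config regardless of its
# attribute name, and falls back to the top-level config for text-only models.
self.text_config = model.config.get_text_config()
if not self.text_config.use_cache:
    raise ValueError("The model must have caching enabled to be performant.")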


if hasattr(model.config.text_config, "layer_types") and getattr(model.config.text_config, "sliding_window", None) is not None:
self.model = TorchExportableModuleWithHybridCache(model, max_batch_size, max_cache_len)
else:
# If `layer_types` is not specified explicitly in the config or `sliding_window` is null,
# there is only 1 type of layers, so export will use `StaticCache` by default.
logging.info(
"Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
)
self.model = TorchExportableModuleWithStaticCache(model)

# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
self.model.model.config._attn_implementation = "sdpa_without_vmap"
Reviewer comment:
Let's use the public API: model.set_attn_implementation("sdpa_without_vmap")
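A sketch of the suggested change, keeping the function registration but switching via the public setter (assuming `set_attn_implementation` accepts a registered custom key):

ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
# Public setter instead of poking the private _attn_implementation attribute.
self.model.model.set_attn_implementation("sdpa_without_vmap")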


def forward(
self,
inputs_embeds: torch.Tensor,
cache_position: torch.Tensor,
) -> torch.Tensor:
"""
Forward pass of the module, which is compatible with the ExecuTorch llm runner.

Args:
inputs_embeds (`torch.Tensor`): Tensor representing input embeddings.
cache_position (`torch.Tensor`): Tensor representing current input position in the cache.

Returns:
torch.Tensor: Logits output from the model.
"""
return self.model.forward(inputs_embeds=inputs_embeds, cache_position=cache_position)

def export(
self,
inputs_embeds: Optional[torch.Tensor] = None,
cache_position: Optional[torch.Tensor] = None,
dynamic_shapes: Optional[dict] = None,
strict: Optional[bool] = None,
) -> torch.export.ExportedProgram:
"""
Export the wrapped module using `torch.export`.

Args:
inputs_embeds (`Optional[torch.Tensor]`):
Tensor representing input embeddings. If not provided, a default tensor will be used.
cache_position (`Optional[torch.Tensor]`):
Tensor representing current input position in the cache. If not provided, a default tensor will be used.
dynamic_shapes (`Optional[dict]`):
Dynamic shapes to use for export if specified.
strict (`Optional[bool]`):
Flag to instruct `torch.export` to use `torchdynamo`.
"""
if hasattr(self.model, "base_model_prefix"):
base = getattr(self.model, self.model.base_model_prefix, self.model)
model_device = base.device
elif hasattr(self.model, "model"):
model_device = self.model.model.device
else:
model_device = "cpu"
logging.warning(
"TorchExportableModuleForImageTextLM.export Can't infer device from the model. Set to CPU by default."
)
Reviewer comment on lines +954 to +963:
Hmm, I think model.device would be fine; the model here is the language backbone.
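Under that assumption, the whole device-resolution branch collapses to a one-liner (sketch):

# self.model is the cache wrapper; its .model attribute is the wrapped
# PreTrainedModel (the language backbone), which always exposes .device.
model_device = self.model.model.device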


seq_length = 3
hidden_size = (
    self.model.model.config.text_config.hidden_size
    if hasattr(self.model.model.config, "text_config")
    else self.model.model.config.hidden_size
)

example_inputs_embeds = (
inputs_embeds if inputs_embeds is not None
else torch.zeros(1, seq_length, hidden_size, dtype=torch.float32, device=model_device)
)
example_cache_position = (
cache_position if cache_position is not None
else torch.arange(seq_length, dtype=torch.long, device=model_device)
)

if dynamic_shapes is None:
seq_len_dim = torch.export.Dim("seq_length_dim", max=seq_length)
dynamic_shapes = {
"inputs_embeds": {1: seq_len_dim},
"cache_position": {0: seq_len_dim},
}

exported_program = torch.export.export(
self.model,
args=(),
kwargs={"inputs_embeds": example_inputs_embeds, "cache_position": example_cache_position},
dynamic_shapes=dynamic_shapes,
strict=strict if strict is not None else True,
)
return exported_program
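For context, a possible end-to-end use of this wrapper (the checkpoint name and token ids are illustrative; note the wrapper consumes embeddings rather than token ids):

import torch
from transformers import AutoModelForImageTextToText

model = AutoModelForImageTextToText.from_pretrained("google/gemma-3-4b-it")
wrapper = TorchExportableModuleForImageTextLM(model, max_batch_size=1, max_cache_len=4096)

# Embed the prompt tokens first; forward()/export() take inputs_embeds.
input_ids = torch.tensor([[2, 106, 107]], dtype=torch.long)  # arbitrary example ids
inputs_embeds = model.get_input_embeddings()(input_ids)
cache_position = torch.arange(input_ids.shape[1], dtype=torch.long)

exported = wrapper.export(inputs_embeds=inputs_embeds, cache_position=cache_position)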


class ImageEncoderExportableModule(torch.nn.Module):
"""
A wrapper module designed to make a vision encoder-only model exportable with `torch.export`.
This module ensures that the exported model is compatible with ExecuTorch.
"""

def __init__(self, model):
super().__init__()
self.model = model

def forward(self, pixel_values):
Reviewer comment:
Most models currently require extra inputs such as num_patches, image_attn_mask, etc.
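For illustration, a more permissive signature could pass model-specific extras through (parameter names here are hypothetical, not a fixed transformers contract):

def forward(self, pixel_values, **vision_kwargs):
    # vision_kwargs can carry model-specific extras such as num_patches
    # or image_attn_mask for models whose vision tower requires them.
    vision_outputs = self.model.vision_tower(
        pixel_values=pixel_values, **vision_kwargs
    ).last_hidden_state
    return self.model.multi_modal_projector(vision_outputs)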

"""
Projects the last hidden state from the vision model into language model space.

Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
The tensors corresponding to the input images.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
"""
vision_outputs = self.model.vision_tower(pixel_values=pixel_values).last_hidden_state
image_features = self.model.multi_modal_projector(vision_outputs)
return image_features
Reviewer comment on lines +1014 to +1016:
I guess self.model is the multimodal model. We should use model.get_image_features(), which handles the pipeline correctly for the given model, because some models might need extra ops on top of this.
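A sketch of the suggested simplification (assuming the wrapped model exposes `get_image_features()`, as most recent VLMs in transformers do):

def forward(self, pixel_values):
    # get_image_features() runs the vision tower plus any model-specific
    # post-processing (projection, pooling, extra ops) in one call.
    return self.model.get_image_features(pixel_values=pixel_values)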


def export(
self,
pixel_values: Optional[torch.Tensor] = None,
dynamic_shapes: Optional[dict] = None,
strict: Optional[bool] = None,
) -> torch.export.ExportedProgram:
"""
Export the vision encoder using `torch.export`.

Args:
pixel_values (`Optional[torch.Tensor]`):
Input images tensor. If not provided, a default tensor will be used.
dynamic_shapes (`Optional[dict]`):
Dynamic shapes to use for export if specified.
strict (`Optional[bool]`):
Flag to instruct `torch.export` to use `torchdynamo`.
"""
if hasattr(self.model, "vision_tower") and hasattr(self.model.vision_tower, "config"):
image_size = self.model.vision_tower.config.image_size
num_channels = getattr(self.model.vision_tower.config, "num_channels", 3)
else:
# Default values for vision models
image_size = 224
num_channels = 3

example_pixel_values = (
pixel_values if pixel_values is not None
else torch.randn(1, num_channels, image_size, image_size, dtype=torch.float32)
)

exported_program = torch.export.export(
self,
args=(example_pixel_values,),
kwargs={},
dynamic_shapes=dynamic_shapes,
strict=strict if strict is not None else False,
)
return exported_program
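And a matching usage sketch for the image encoder wrapper (checkpoint name illustrative):

from transformers import AutoModelForImageTextToText

model = AutoModelForImageTextToText.from_pretrained("google/gemma-3-4b-it")
encoder = ImageEncoderExportableModule(model)
# With no pixel_values given, export() builds a random
# (1, num_channels, image_size, image_size) example input.
exported_encoder = encoder.export()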