@larryliu0820 (Collaborator) commented Jul 23, 2025

This PR has some code adapted from transformers. We put it in optimum-executorch so that we can iterate quickly on the stack. Eventually we want to upstream the changes to transformers. See details below.

Exportable Modules

TorchExportableModuleWithHybridCache

A wrapper module that makes decoder-only language models exportable with torch.export using HybridCache. This is a forked version of TorchExportableModuleForDecoderOnlyLM with some modifications to support inputs_embeds.

Note: This class should be upstreamed to transformers. We keep it here so that we can iterate quickly.

TorchExportableModuleForImageTextLM

A wrapper for the text decoder model in a vision-language model. It is very similar to TorchExportableModuleForDecoderOnlyLM, but instead of taking input_ids this module takes inputs_embeds. This is because we want to be able to take both token embeddings and image embeddings as inputs.

Note: This class should be upstreamed to transformers. See huggingface/transformers#39836 for more details; once that lands, we can clean up the class here.
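
To make the interface change concrete, here is a minimal sketch of the forward contract such a wrapper exposes; the class name, constructor, and exact argument list below are illustrative assumptions, not the code in this PR:

import torch

class ImageTextLMWrapperSketch(torch.nn.Module):
    """Illustrative only: a decoder wrapper that consumes pre-computed embeddings."""

    def __init__(self, decoder: torch.nn.Module):
        super().__init__()
        self.decoder = decoder  # decoder-only LM carrying a static/hybrid KV cache

    def forward(self, inputs_embeds: torch.Tensor, cache_position: torch.Tensor) -> torch.Tensor:
        # Unlike a wrapper that takes input_ids, the token-embedding lookup is
        # skipped here: the caller supplies embeddings directly, so image
        # embeddings can be mixed in with token embeddings before the call.
        outputs = self.decoder(
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            use_cache=True,
        )
        return outputs.logits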

ImageEncoderExportableModule

A wrapper for vision encoder models that projects vision features into the language model's embedding space. This is commonly implemented as get_image_features() in HuggingFace transformers, for example Gemma3Model.get_image_features().
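
For reference, the wrapper roughly follows the get_image_features() pattern (vision tower output projected into the LM embedding space); a minimal sketch, assuming Gemma3-style attribute names (vision_tower, multi_modal_projector), which may differ for other models:

import torch

class ImageEncoderSketch(torch.nn.Module):
    """Illustrative only: project vision features into the LM embedding space."""

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model  # full vision-language model, e.g. Gemma3Model

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # Run the vision encoder, then project its hidden states so they can
        # later be merged into the text decoder's inputs_embeds.
        vision_outputs = self.model.vision_tower(pixel_values=pixel_values).last_hidden_state
        image_features = self.model.multi_modal_projector(vision_outputs)
        return image_features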

ImageTextToTextExportableModule

A torch.nn.Module wrapper for the image-text-to-text task. It provides an export() API that generates ExportedPrograms, which are consumed by the xnnpack.py recipe to produce the ExecuTorch program.
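
For orientation, the lowering an XNNPACK recipe typically performs on each ExportedProgram looks roughly like the sketch below; it uses the public ExecuTorch APIs as an illustration and is not the recipe's actual code:

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge

def lower_to_executorch(exported_program: torch.export.ExportedProgram) -> bytes:
    # Convert the ATen-level ExportedProgram to the Edge dialect, delegate the
    # supported subgraphs to the XNNPACK backend, then serialize a .pte blob.
    edge = to_edge(exported_program)
    edge = edge.to_backend(XnnpackPartitioner())
    return edge.to_executorch().buffer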

Usage

from optimum.executorch import ExecuTorchModelForMultimodalCausalLM

model_id = "google/gemma-3-4b-it"

model = ExecuTorchModelForMultimodalCausalLM.from_pretrained(
    model_id,
    recipe="xnnpack",
    task="image-text-to-text",
    export=True,
    use_custom_sdpa=True,
    use_custom_kv_cache=True,
    qlinear=True,
    qembedding_config=True,
)

Testing

Run tests with:

RUN_SLOW=1 pytest tests/models/test_modeling_gemma3.py::ExecuTorchModelIntegrationTest::test_gemma3_image_text_to_text_generation_with_custom_sdpa_kv_cache_8da4w_8we

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@larryliu0820 changed the title from "DRAFT - Support image-text-to-text task" to "Support image-text-to-text task" on Aug 1, 2025
# TODO: Should switch to `AOPerModuleConfig` once fix for tied weights is available.
embedding_config = IntxWeightOnlyConfig(
weight_dtype=torch.int8,
granularity=PerAxis(0),

Collaborator:
why not groupwise?

Collaborator Author:
Let's iterate on this later, currently this quantization config works fine.

)

if qlinear_config:
logging.info("Quantizing linear layers.")

Collaborator:
could try PerAxis here for the encoder

Collaborator Author:
See comment above
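
For context on the two quantization threads above, below is a sketch of the kind of torchao configs being discussed: per-axis int8 weight-only for embeddings versus groupwise 4-bit weights with 8-bit dynamic activations (8da4w) for linear layers. The config class names and the filter_fn usage are assumptions based on recent torchao, not necessarily what this PR ships:

import torch
from torchao.quantization import quantize_
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    IntxWeightOnlyConfig,
)

def quantize_sketch(model: torch.nn.Module) -> None:
    # 8-bit weight-only quantization, per output channel, for embedding tables.
    embedding_config = IntxWeightOnlyConfig(
        weight_dtype=torch.int8,
        granularity=PerAxis(0),
    )
    quantize_(model, embedding_config, filter_fn=lambda m, fqn: isinstance(m, torch.nn.Embedding))

    # 8-bit dynamic activations with 4-bit groupwise weights (8da4w) for linear layers.
    linear_config = Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerGroup(32),
    )
    quantize_(model, linear_config)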

logger.warning(f"task was provided and set to {task} but not used, will be ignored")
inferred_task = TasksManager.infer_task_from_model(cls.auto_model_class)
logging.info(f"Inferred task from model class: {inferred_task}")
logger.warning(f"task was provided and set to {task}")

Collaborator:
seems ok, I had to do this too but @guangy10 thoughts on this?

return exported_program


class ImageEncoderExportableModule(torch.nn.Module):

Collaborator:
Should look into whether the vision embeddings -> multimodal projector is gemma-specific or generally applicable across the board for encoders. It's possible that other vision models have a few extra steps in here. In that case maybe it makes sense to just call it GemmaImageEncoderExportableModule, and maybe create a new dir and put it into exporters/executorch/models/gemma for per-model exportable modules

Collaborator:
what is hf transformers pattern here?

Collaborator:
The similar pattern is that they just write model-specific code for new models in modular_.py, so I think it's fine that we have some model-specific code

return image_features


class ImageTextToTextExportableModule(torch.nn.Module):

Collaborator:
I remember you had code for verifying the ExportedProgram E2E in the original draft PR; you can add that to def generate() here and add a test for it too

Collaborator Author:
Is it common to have generate() implemented here?
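
As a concrete example of the E2E verification being suggested, a minimal sketch that re-runs an ExportedProgram against the eager module on the same inputs (names are illustrative; this is not the code from the original draft PR):

import torch

def verify_exported_program(
    eager_module: torch.nn.Module,
    exported_program: torch.export.ExportedProgram,
    example_inputs: tuple,
) -> None:
    # Run the exported graph module and the eager module on identical inputs
    # and check that the outputs agree within tolerance.
    eager_out = eager_module(*example_inputs)
    exported_out = exported_program.module()(*example_inputs)
    torch.testing.assert_close(exported_out, eager_out, rtol=1e-3, atol=1e-3)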

setup.py Outdated
"optimum~=1.24",
"executorch>=0.6.0",
"transformers==4.51.3",
"transformers==4.53.2",

Collaborator:
Should upgrade transformers in a separate PR since it could be problematic

# Branch out for float vs bool mask
# assert attention_mask.dim() == 2, f"attention_mask must be a 2D matrix."
attention_mask = attention_mask.reshape(-1, max_seq_len)
attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1])

Collaborator:
why this change?

Collaborator Author:
This was giving a weird issue when verifying the e2e workflow using the ExportedProgram. I forgot what exactly, though

Collaborator:
if not needed, please undo

Collaborator Author (@larryliu0820, Aug 8, 2025):
No we definitely need this, otherwise e2e won’t work.

special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

Collaborator:
@larryliu0820 so not doing this at runtime means we make assumptions about where the image tokens go in the prompt, right?

Collaborator Author:
Yeah the runner will have to take in a vector of inputs, then prefill sequentially.

Collaborator:
I find the sequential prefilling a bit strange in the runner, mainly because it assumes a format for the chat template, i.e. that image tokens come last. You really need to do masked scatter, no?

Collaborator Author:
The runner knows nothing about the chat template. It only sees [image, text, image..]

Collaborator:
I understand that, but where image tokens (or, in the future, speech tokens) go is a property of the model's chat template, isn't it? So whether it is managed in the runner or the layer above doesn't matter, but it has to be accounted for somewhere
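
To make this discussion concrete, here is a rough sketch of the sequential-prefill flow being described, not the actual runner implementation; the chunk format and the callables are assumptions:

from typing import Callable, List, Tuple

import torch

def sequential_prefill(
    chunks: List[Tuple[str, torch.Tensor]],
    embed_tokens: Callable[[torch.Tensor], torch.Tensor],
    encode_image: Callable[[torch.Tensor], torch.Tensor],
    decode_step: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    # chunks is an ordered list like [("text", input_ids), ("image", pixel_values), ...].
    # Each chunk is embedded (text) or encoded (image), and prefill advances
    # cache_position chunk by chunk, so the KV cache ends up in the same state
    # as a single masked_scatter-merged prefill, provided the chunk order
    # already reflects where the image tokens sit in the prompt.
    pos = 0
    logits = None
    for kind, value in chunks:
        embeds = encode_image(value) if kind == "image" else embed_tokens(value)
        cache_position = torch.arange(pos, pos + embeds.shape[1])
        logits = decode_step(embeds, cache_position)
        pos += embeds.shape[1]
    return logits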


if (
hasattr(model.config.text_config, "layer_types")
and getattr(model.config.text_config, "sliding_window", None) is not None

Collaborator:
so this only works for gemma3?

Collaborator Author:
not 100% sure haha. Will use it to enable a few more models.

Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
"""
vision_outputs = self.model.vision_tower(

Collaborator:
so this relies on the fact that there is a vision_tower attr on the model that is the vision encoder

Collaborator Author:
Yeah this should work for llava as well.

Collaborator:
yeah but is this something we can rely on? Like, upstreaming this change might be difficult? Mainly the question is, how much of the model structure information are you exploiting?

"""
vision_outputs = self.model.vision_tower(
pixel_values=pixel_values
).last_hidden_state

Collaborator:
same here and the next line

Comment on lines +300 to +301
sliding_window = self.metadata.get("sliding_window", float("inf"))
max_dim = min(max_seq_len, sliding_window) - 1

Collaborator:
does a similar constraint exist for sliding window in the decoder-only LM?

RemoveRedundantTransposes,
)

mutated_gm = RemoveRedundantTransposes()(exported_program.module())[0]

Collaborator:
we should run this pass for other exported models as well

)

token_embeddings_exported_program = torch.export.export(
exportable_module.model.model.language_model.get_input_embeddings(),

Collaborator:
I don't follow this. We already export exportable_module.model, so I would have expected that get_input_embeddings() is traced as part of that? I guess that's not the case when inputs_embeds != None

Collaborator Author:
get_input_embeddings() has not been traced because in the language model we specialized on inputs_embeds != None and skipped the token embedding layer.

Collaborator:
yeah, I don't quite like the fact that we are exploiting information from the model code though.

Collaborator Author:
I guess that’s inevitable and I hope transformers folks can give some guarantees lol. Like model.vision_model and model.get_image_features()
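
Relatedly, a minimal sketch of the separate token-embedding export discussed in this thread, using a stand-in nn.Embedding in place of get_input_embeddings(); the shapes and the dynamic dimension are illustrative:

import torch

# Stand-in for language_model.get_input_embeddings(): since the decoder program
# was specialized on inputs_embeds, the embedding lookup has to be exported as
# its own program so the runner can turn input_ids into embeddings first.
embed_layer = torch.nn.Embedding(num_embeddings=1000, embedding_dim=64)
example_ids = torch.zeros((1, 8), dtype=torch.long)
seq_len = torch.export.Dim("seq_len", min=2, max=128)
token_embeddings_ep = torch.export.export(
    embed_layer,
    (example_ids,),
    dynamic_shapes=({1: seq_len},),
)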

weight_dtype=torch.int4,
weight_granularity=PerGroup(32),
)
quantize_(

Collaborator:
Not quantizing vision model?
