Support image-text-to-text task #111
Conversation
    # TODO: Should switch to `AOPerModuleConfig` once fix for tied weights is available.
    embedding_config = IntxWeightOnlyConfig(
        weight_dtype=torch.int8,
        granularity=PerAxis(0),
why not groupwise?
Let's iterate on this later, currently this quantization config works fine.
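For context, a minimal sketch of the two options being discussed, assuming recent torchao APIs; the per-axis variant is what the diff uses, the groupwise one is the alternative raised above:

```python
import torch
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_

# Per-axis int8 weight-only quantization: one scale per embedding row (axis 0),
# as in the diff above.
embedding_config = IntxWeightOnlyConfig(
    weight_dtype=torch.int8,
    granularity=PerAxis(0),
)

# The groupwise alternative asked about: one scale per group of 32 weights.
groupwise_config = IntxWeightOnlyConfig(
    weight_dtype=torch.int8,
    granularity=PerGroup(32),
)

# Either config would then be applied to the embedding layers only, e.g.:
# quantize_(model, embedding_config,
#           filter_fn=lambda m, fqn: isinstance(m, torch.nn.Embedding))
```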
    )
    if qlinear_config:
        logging.info("Quantizing linear layers.")
could try per-axis here for the encoder
See comment above
| logger.warning(f"task was provided and set to {task} but not used, will be ignored") | ||
| inferred_task = TasksManager.infer_task_from_model(cls.auto_model_class) | ||
| logging.info(f"Inferred task from model class: {inferred_task}") | ||
| logger.warning(f"task was provided and set to {task}") |
seems ok, I had to do this too but @guangy10 thoughts on this?
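For reference, a rough sketch of what the task inference amounts to, assuming optimum's TasksManager API; the checkpoint name is only an illustration, the diff passes the auto model class instead:

```python
from optimum.exporters.tasks import TasksManager

# Infer the task from the model rather than trusting the user-supplied value.
inferred_task = TasksManager.infer_task_from_model("google/gemma-3-4b-it")
print(inferred_task)  # expected to resolve to "image-text-to-text"
```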
    return exported_program

    class ImageEncoderExportableModule(torch.nn.Module):
Should look into whether the vision embeddings -> multimodal projector flow is Gemma-specific or generally applicable across the board for encoders. It's possible that other vision models have a few extra steps in here. In that case maybe it makes sense to just call it GemmaImageEncoderExportableModule, and maybe create a new dir and put it into exporters/executorch/models/gemma for per-model exportable modules.
what is the hf transformers pattern here?
The similar pattern is that they just write model-specific code for new models in modular_.py, so I think it's fine that we have some model-specific code
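Roughly, the wrapper under discussion looks like the sketch below; the vision_tower and multi_modal_projector attribute names follow Gemma3/LLaVA, and this is not the PR's exact code:

```python
import torch

class ImageEncoderExportableModule(torch.nn.Module):
    """Run the vision tower, then project vision features into the language
    model's embedding space (mirroring Gemma3Model.get_image_features())."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, pixel_values):
        vision_outputs = self.model.vision_tower(pixel_values=pixel_values).last_hidden_state
        image_features = self.model.multi_modal_projector(vision_outputs)
        return image_features
```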
    return image_features

    class ImageTextToTextExportableModule(torch.nn.Module):
I remember you had code for verifying the ExportedProgram E2E in the original draft PR; you can add that to def generate() here and add a test for it too.
Is it common to have generate() implemented here?
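If a generate() helper is added for E2E verification, it could look roughly like this sketch. The decoder wrapper's (inputs_embeds, cache_position) signature and the separate token-embedding and image-encoder programs are assumptions based on the rest of this PR, not its actual code:

```python
import torch

def generate(token_embeddings_ep, image_encoder_ep, decoder_ep,
             input_ids, pixel_values, image_token_id, max_new_tokens=16):
    # Embed the prompt and the image, then merge image features into the
    # embedding sequence at the image-token positions.
    inputs_embeds = token_embeddings_ep.module()(input_ids)
    image_features = image_encoder_ep.module()(pixel_values).to(inputs_embeds.dtype)
    mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
    inputs_embeds = inputs_embeds.masked_scatter(mask, image_features)

    # Prefill once, then greedy-decode one token at a time.
    generated = []
    pos = torch.arange(inputs_embeds.shape[1], dtype=torch.long)
    logits = decoder_ep.module()(inputs_embeds=inputs_embeds, cache_position=pos)
    next_token = logits[:, -1, :].argmax(dim=-1)
    for _ in range(max_new_tokens):
        generated.append(next_token.item())
        pos = pos[-1:] + 1
        next_embeds = token_embeddings_ep.module()(next_token.unsqueeze(0))
        logits = decoder_ep.module()(inputs_embeds=next_embeds, cache_position=pos)
        next_token = logits[:, -1, :].argmax(dim=-1)
    return generated
```

A test would then compare the tokens from this loop against the eager model's generate() output.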
setup.py (outdated)
| "optimum~=1.24", | ||
| "executorch>=0.6.0", | ||
| "transformers==4.51.3", | ||
| "transformers==4.53.2", |
Should upgrade transformers in a separate PR since it could be problematic.
    # Branch out for float vs bool mask
    # assert attention_mask.dim() == 2, f"attention_mask must be a 2D matrix."
    attention_mask = attention_mask.reshape(-1, max_seq_len)
    attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1])
why this change?
This was giving a weird issue when verifying the e2e workflow using the ExportedProgram. I forget what exactly, though.
If not needed, please undo.
No, we definitely need this; otherwise e2e won’t work.
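One plausible reason the second form matters under export: reshaping against the tensor's own last dimension keeps the sequence length symbolic, whereas baking in a Python int specializes the graph on that value. A minimal illustration (shapes are made up):

```python
import torch
from torch.export import Dim, export

class MaskReshape(torch.nn.Module):
    def forward(self, attention_mask):
        # Use the tensor's own (possibly symbolic) last dim, not a hard-coded constant.
        return attention_mask.reshape(-1, attention_mask.shape[-1])

mask = torch.ones(2, 128)
ep = export(MaskReshape(), (mask,),
            dynamic_shapes={"attention_mask": {1: Dim("seq_len", max=4096)}})
print(ep)
```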
    special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
    special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
    image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
    inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
@larryliu0820 so not doing this at runtime means we make assumptions about where the image tokens go in the prompt, right?
Yeah the runner will have to take in a vector of inputs, then prefill sequentially.
I find the sequential prefilling in the runner a bit strange, mainly because it assumes a format for the chat template, i.e. that image tokens come last. You really need to do masked scatter, no?
The runner knows nothing about the chat template. It only sees [image, text, image..]
I understand that, but where image tokens go, or in the future where speech tokens go, is a property of the model's chat template, isn't it? So whether it is managed in the runner or the layer above doesn't matter, but it would have to be accounted for somewhere.
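To make the masked-scatter path concrete, a toy example (the token id and the shapes are made up):

```python
import torch

IMAGE_TOKEN_ID = 9999  # placeholder; the real id comes from model.config.image_token_id
input_ids = torch.tensor([[1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2]])
inputs_embeds = torch.zeros(1, 4, 3)   # text embeddings; image positions still empty
image_features = torch.ones(1, 2, 3)   # projected image embeddings for the two image tokens

mask = (input_ids == IMAGE_TOKEN_ID).unsqueeze(-1).expand_as(inputs_embeds)
merged = inputs_embeds.masked_scatter(mask, image_features.to(inputs_embeds.dtype))
print(merged[0, 1])  # tensor([1., 1., 1.]) -- image features landed on the image-token rows
```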
    if (
        hasattr(model.config.text_config, "layer_types")
        and getattr(model.config.text_config, "sliding_window", None) is not None
so this only works for gemma3?
not 100% sure haha. Will use it to enable a few more models.
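A quick way to check which configs satisfy this condition (the checkpoint name is illustrative):

```python
from transformers import AutoConfig

def has_hybrid_layer_types(model_id: str) -> bool:
    config = AutoConfig.from_pretrained(model_id)
    text_config = getattr(config, "text_config", config)
    return (
        hasattr(text_config, "layer_types")
        and getattr(text_config, "sliding_window", None) is not None
    )

print(has_hybrid_layer_types("google/gemma-3-4b-it"))  # expected True (sliding + global layers)
```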
    Returns:
        image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
    """
    vision_outputs = self.model.vision_tower(
so this relies on the fact that there is a vision_tower attr on the model that is the vision encoder
Yeah this should work for llava as well.
yeah, but is this something we can rely on? Like, upstreaming this change might be difficult. Mainly the question is how much of the model's structure information you are exploiting.
| """ | ||
| vision_outputs = self.model.vision_tower( | ||
| pixel_values=pixel_values | ||
| ).last_hidden_state |
same here and the next line
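If relying on the vision_tower attribute turns out to be too brittle, a small defensive lookup could be used instead; this helper is hypothetical and not part of the PR:

```python
import torch

def get_vision_encoder(model: torch.nn.Module) -> torch.nn.Module:
    # Gemma3 and LLaVA expose the vision encoder as `vision_tower`; some other
    # multimodal models use `vision_model`, so try both before giving up.
    for attr in ("vision_tower", "vision_model"):
        encoder = getattr(model, attr, None)
        if encoder is not None:
            return encoder
    raise AttributeError(f"{type(model).__name__} does not expose a known vision encoder attribute")
```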
    sliding_window = self.metadata.get("sliding_window", float("inf"))
    max_dim = min(max_seq_len, sliding_window) - 1
does a similar constraint exist for the sliding window in a decoder-only LM?
        RemoveRedundantTransposes,
    )

    mutated_gm = RemoveRedundantTransposes()(exported_program.module())[0]
we should run this pass for other exported models as well
    )

    token_embeddings_exported_program = torch.export.export(
        exportable_module.model.model.language_model.get_input_embeddings(),
I don't follow this. We already export exportable_module.model, so I would have expected get_input_embeddings to be traced as part of that? I guess that's not the case when inputs_embeds != None.
get_input_embeddings() has not been traced because in the language model we specialized on input_embeds != None and skipped the token embedding layer.
yeah, I don't quite like the fact that we are exploiting information from the model code, though.
I guess that’s inevitable and I hope transformers folks can give some guarantees lol. Like model.vision_model and model.get_image_features()
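For illustration, exporting the token-embedding layer on its own could look like the sketch below; the attribute path mirrors the diff, while the checkpoint name and sequence bound are assumptions:

```python
import torch
from torch.export import Dim, export
from transformers import AutoModelForImageTextToText

# The text decoder is exported against `inputs_embeds`, so the embedding lookup
# never makes it into that graph; export it as a separate program for the runner.
model = AutoModelForImageTextToText.from_pretrained("google/gemma-3-4b-it")
token_embeddings = model.model.language_model.get_input_embeddings()

example_ids = torch.zeros(1, 3, dtype=torch.long)
token_embeddings_ep = export(
    token_embeddings,
    (example_ids,),
    dynamic_shapes=({1: Dim("seq_len", max=1024)},),
)
```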
        weight_dtype=torch.int4,
        weight_granularity=PerGroup(32),
    )
    quantize_(
Not quantizing vision model?
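If the vision model should be quantized too, the same torchao flow could be pointed at the vision tower; this is a sketch and not part of the PR, and `model` is assumed to be the loaded multimodal model:

```python
import torch
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_

# Hypothetical follow-up: int4 weight-only, groupwise quantization of the
# vision tower's linear layers (attribute name as used elsewhere in this PR).
vision_config = IntxWeightOnlyConfig(
    weight_dtype=torch.int4,
    granularity=PerGroup(32),
)
quantize_(model.vision_tower, vision_config)
```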
This PR has some code adopted from transformers. We put it in optimum-executorch so that we can iterate fast on the stack. Eventually we want to upstream changes to transformers. See details below.

Exportable Modules
TorchExportableModuleWithHybridCache
A wrapper module that makes decoder-only language models exportable with torch.export using HybridCache. This is a forked version of TorchExportableModuleForDecoderOnlyLM with some modifications to support inputs_embeds.
Note: This class should be upstreamed to transformers. We keep it here so that we can iterate quickly.
TorchExportableModuleForImageTextLM
A wrapper for the text decoder model in a vision-language model. It is very similar to TorchExportableModuleForDecoderOnlyLM, but instead of taking input_ids this module takes inputs_embeds. This is because we want to be able to take both token embeddings and image embeddings as inputs.
Note: This class should be upstreamed to transformers. Please see huggingface/transformers#39836 for more details; once that lands we can clean up the class here.
ImageEncoderExportableModule
A wrapper for vision encoder models that projects vision features to the language model space. Commonly implemented as get_image_features() in HuggingFace transformers. For example: Gemma3Model.get_image_features().

ImageTextToTextExportableModule
A wrapper of torch.nn.Module for the image-text-to-text task. It provides an export() API that generates an ExportedProgram, which will be consumed by the xnnpack.py recipe to generate the ExecuTorch program.

Usage
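The usage snippet itself is not shown above; with the existing optimum-executorch CLI, the export would presumably be invoked along these lines (the checkpoint name, task value, and output directory are assumptions, not taken from the PR):

```bash
optimum-cli export executorch \
  --model google/gemma-3-4b-it \
  --task image-text-to-text \
  --recipe xnnpack \
  --output_dir gemma3_executorch
```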
Testing
Run tests with: