-
Notifications
You must be signed in to change notification settings - Fork 25
Add SmolVLM #163
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
jackzhxng
wants to merge
5
commits into
huggingface:main
Choose a base branch
from
jackzhxng:jz/smolvlm
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Add SmolVLM #163
Changes from 2 commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,6 +22,7 @@ | |
| from transformers import ( | ||
| AutoConfig, | ||
| AutoProcessor, | ||
| AutoTokenizer, | ||
| PreTrainedModel, | ||
| StaticCache, | ||
| T5ForConditionalGeneration, | ||
|
|
@@ -34,18 +35,63 @@ | |
|
|
||
| from optimum.executorch.attentions.custom_sdpa import get_custom_sdpa_for_ring_kv_cache | ||
|
|
||
| from .utils import apply_chat_template_with_fallback, save_config_to_constant_methods | ||
| from .utils import apply_chat_template_with_fallback, process_conversation_inputs, save_config_to_constant_methods | ||
|
|
||
| def _patch_idefics3_vision_embeddings_for_export(vision_model): | ||
| """ | ||
| Patch Idefics3VisionEmbeddings to make it export-friendly by removing data-dependent operations. | ||
| This assumes batch_size=1 and a full attention mask (all 1s). | ||
| """ | ||
| import types | ||
|
|
||
| def export_friendly_forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor: | ||
| batch_size, _, max_im_h, max_im_w = pixel_values.shape | ||
|
|
||
| patch_embeds = self.patch_embedding(pixel_values) | ||
| embeddings = patch_embeds.flatten(2).transpose(1, 2) | ||
|
|
||
| nb_patches_h = max_im_h // self.patch_size | ||
| nb_patches_w = max_im_w // self.patch_size | ||
| N = self.num_patches_per_side | ||
|
|
||
| # For export, we assume full attention mask and compute position IDs statically. | ||
| # This avoids the data-dependent loop over batch dimension. | ||
| h_indices = torch.arange(nb_patches_h, device=pixel_values.device, dtype=torch.long) | ||
| w_indices = torch.arange(nb_patches_w, device=pixel_values.device, dtype=torch.long) | ||
|
|
||
| # This replaces bucketize(x, boundaries=[1/N, 2/N, ...], right=True) ≈ floor(x * N), which | ||
|
||
| # we don't have a kernel for at the moment. | ||
| bucket_coords_h = (h_indices * N) // nb_patches_h | ||
| bucket_coords_w = (w_indices * N) // nb_patches_w | ||
|
|
||
| bucket_coords_h = torch.clamp(bucket_coords_h, max=N - 1) | ||
| bucket_coords_w = torch.clamp(bucket_coords_w, max=N - 1) | ||
|
|
||
| pos_ids = (bucket_coords_h[:, None] * N + bucket_coords_w[None, :]).reshape(-1) | ||
| position_ids = pos_ids.unsqueeze(0).expand(batch_size, -1) | ||
| embeddings = embeddings + self.position_embedding(position_ids) | ||
| return embeddings | ||
|
|
||
| # Patch the forward method. | ||
| vision_model.embeddings.forward = types.MethodType(export_friendly_forward, vision_model.embeddings) | ||
|
|
||
|
|
||
| class VisionExportableModule(torch.nn.Module): | ||
| def __init__(self, model: torch.nn.Module): | ||
| super().__init__() | ||
| self.model = model | ||
|
|
||
| # Patch Idefics3 vision embeddings if needed | ||
| if hasattr(model, 'model') and hasattr(model.model, 'vision_model'): | ||
| model_type = getattr(model.config, 'model_type', '') | ||
| if 'idefics3' in model_type.lower(): | ||
| _patch_idefics3_vision_embeddings_for_export(model.model.vision_model) | ||
|
|
||
| def prepare_export_inputs(self): | ||
| # 1. Get export inputs | ||
| model_id = self.model.config.name_or_path | ||
| processor = AutoProcessor.from_pretrained(model_id) | ||
| tokenizer = AutoTokenizer.from_pretrained(model_id) | ||
| sample_conversation_with_image = [ | ||
| { | ||
| "role": "user", | ||
|
|
@@ -54,12 +100,10 @@ def prepare_export_inputs(self): | |
| ], | ||
| }, | ||
| ] | ||
| processed_inputs = processor.apply_chat_template( | ||
| processed_inputs = process_conversation_inputs( | ||
| processor, | ||
| tokenizer, | ||
| sample_conversation_with_image, | ||
| add_generation_prompt=True, | ||
| tokenize=True, | ||
| return_dict=True, | ||
| return_tensors="pt", | ||
| ) | ||
| if "pixel_values" not in processed_inputs: | ||
| raise ValueError( | ||
|
|
@@ -76,7 +120,9 @@ def forward( | |
| self, | ||
| input_features: torch.FloatTensor, | ||
| ): | ||
| image_embeds = self.model.get_image_features(input_features) | ||
| # Pass pixel_attention_mask=None to avoid data-dependent operations during export. | ||
| # The model will create a mask full of 1s internally if None is passed. | ||
| image_embeds = self.model.get_image_features(input_features, pixel_attention_mask=None) | ||
| if isinstance(image_embeds, list): | ||
| image_embeds = torch.stack(image_embeds) | ||
| return image_embeds | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@zucchini-nlp This part could not export because of the data dependent loop, I unroll it here to just handle one image. Here's the original code -
https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics3/modeling_idefics3.py#L149
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmm, i thought it was fixed in huggingface/transformers#39614
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah that one just exports the vision encoder, this one exports the get_image_features function which calls the vision encoder. I'm thinking that some code in this function might be confusing to export cc @tugsbayasgalan
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ahh i see, prob it is the part where images are unpadded and that code is value dependant.
For my understanding, can we export the vision and the LM part separately but not the merging logic? Most
get_image_featuresmight not be 100% exportable for VLMs, so keeping it outside ofexportcan work better on long termThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The reason we export
get_image_featuresis that if we only exported the vision encoder, we would need to write the rest of the merging logic in C++ which is difficult to do and harder to scale. I wonder if it's possible to upstream an exportable version ofget_image_features? If not, I'm happy to just monkey patch like thisThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This would be the best solution, and if you want to submit a PR it'll very welcome. I see that we're passing
pixel_attention_mask=Nonein all cases in this PR but the vision backbone is still not exportable?