
Conversation

@jackzhxng jackzhxng commented Oct 10, 2025

 The image depicts a clear and vivid view of a cityscape with a prominent focus on the Statue of Liberty in the foreground. The Statue of Liberty is centrally positioned and prominently displayed, standing on a small, rocky platform that extends into the water. The water in the background is a deep blue, with a few ships visible
⚠️ DISCLAIMER: Python-based perf measurements are approximate and may not match absolute speeds on Android/iOS apps. They are intended for relative comparisons (e.g. SDPA vs. custom SDPA, FP16 vs. FP32) so you can gauge performance improvements from each optimization step. For end-to-end, platform-accurate benchmarks, please use the official ExecuTorch apps:
  • iOS:     https://github.com/pytorch/executorch/tree/main/extension/benchmark/apple/Benchmark
  • Android: https://github.com/pytorch/executorch/tree/main/extension/benchmark/android/benchmark

PyTorchObserver {"prompt_tokens": 1197, "generated_tokens": 64, "model_load_start_ms": 0, "model_load_end_ms": 0, "inference_start_ms": 1760545277225, "token_encode_end_ms": 1760545277506, "model_execution_start_ms": 0, "model_execution_end_ms": 0, "inference_end_ms": 1760545314593, "prompt_eval_end_ms": 1760545308404, "first_token_ms": 1760545308515, "aggregate_sampling_time_ms": 37068, "SCALING_FACTOR_UNITS_PER_SECOND": 1000}
        Prompt Tokens: 1197 Generated Tokens: 64
        Model Load Time:                0.000000 (seconds)
        Total inference time:           37.368000 (seconds)              Rate:  1.712695 (tokens/second)
                Prompt evaluation:      31.179000 (seconds)              Rate:  38.391225 (tokens/second)
                Generated 64 tokens:    6.189000 (seconds)               Rate:  10.340927 (tokens/second)
        Time to first generated token:  31.290000 (seconds)
        Sampling time over 1261 tokens: 37.068000 (seconds)

INFO     root:utils.py:87 Starting perplexity check with multimodal model 'HuggingFaceTB/SmolVLM-Instruct' ...
Fetching 2 files: 100%|███████████████████████████████████████████████████| 2/2 [00:00<00:00, 5349.88it/s]
INFO     root:utils.py:135 ✓ Perplexity check passed: 2.84 <= 5
PASSED
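
For reference, the summary rates above follow directly from the PyTorchObserver JSON. A quick sketch (not part of the ExecuTorch runner) that reproduces them:

# Reproduce the summary lines from the PyTorchObserver JSON printed above.
obs = {"prompt_tokens": 1197, "generated_tokens": 64,
       "inference_start_ms": 1760545277225, "inference_end_ms": 1760545314593,
       "prompt_eval_end_ms": 1760545308404, "first_token_ms": 1760545308515}

total_s = (obs["inference_end_ms"] - obs["inference_start_ms"]) / 1000      # 37.368
prompt_s = (obs["prompt_eval_end_ms"] - obs["inference_start_ms"]) / 1000   # 31.179
gen_s = (obs["inference_end_ms"] - obs["prompt_eval_end_ms"]) / 1000        # 6.189
ttft_s = (obs["first_token_ms"] - obs["inference_start_ms"]) / 1000         # 31.290

print(f"prompt eval: {obs['prompt_tokens'] / prompt_s:.2f} tok/s")    # ~38.39
print(f"generation:  {obs['generated_tokens'] / gen_s:.2f} tok/s")    # ~10.34
print(f"overall:     {obs['generated_tokens'] / total_s:.2f} tok/s")  # ~1.71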

Requires the following Transformers changes (also WIP) - huggingface/transformers#41629

@jackzhxng jackzhxng changed the title Add SmolVLM [WIP] Add SmolVLM Oct 10, 2025
"""
import types

def export_friendly_forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
@jackzhxng jackzhxng commented Oct 10, 2025

@zucchini-nlp This part could not be exported because of the data-dependent loop, so I unroll it here to handle just one image. Here's the original code:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics3/modeling_idefics3.py#L149
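
For anyone following along, here is a simplified sketch of the pattern in the linked code (not the actual transformers implementation): the loop walks the batch and derives patch counts from .sum() over the mask, so downstream tensor sizes depend on runtime data, which torch.export guards on. Unrolling to a single image with static patch counts removes both the loop and the data dependence.

import torch

# Sketch of the data-dependent pattern in the original loop (simplified).
def position_ids_data_dependent(patch_attention_mask: torch.BoolTensor, n: int) -> torch.Tensor:
    position_ids = torch.zeros(patch_attention_mask.shape[0], n * n, dtype=torch.long)
    boundaries = torch.arange(1 / n, 1.0, 1 / n)
    for idx, mask in enumerate(patch_attention_mask):      # loop over the batch
        nb_h = int(mask[:, 0].sum())                       # data-dependent -> export guard
        nb_w = int(mask[0].sum())
        bucket_h = torch.bucketize(torch.arange(nb_h) / nb_h, boundaries, right=True)
        bucket_w = torch.bucketize(torch.arange(nb_w) / nb_w, boundaries, right=True)
        pos = (bucket_h[:, None] * n + bucket_w).flatten()
        position_ids[idx][mask.view(-1)] = pos             # data-dependent boolean indexing
    return position_ids

# Export-friendly sketch: one image, patch counts are static Python ints at export time.
def position_ids_single_image(nb_h: int, nb_w: int, n: int) -> torch.Tensor:
    bucket_h = (torch.arange(nb_h) * n) // nb_h            # == floor(k / nb_h * n)
    bucket_w = (torch.arange(nb_w) * n) // nb_w
    return (bucket_h[:, None] * n + bucket_w).flatten().unsqueeze(0)

mask = torch.zeros(1, 4, 4, dtype=torch.bool)
mask[0, :3, :2] = True                                     # a 3x2 attended grid in a 4x4 layout
assert torch.equal(position_ids_data_dependent(mask, n=4)[0][mask.view(-1)],
                   position_ids_single_image(3, 2, n=4)[0])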

Member commented:

Hmm, I thought this was fixed in huggingface/transformers#39614

@jackzhxng jackzhxng commented Oct 13, 2025

Ah, that one only exports the vision encoder; this one exports the get_image_features function, which calls the vision encoder. I'm thinking that some code in this function might be confusing the exporter. cc @tugsbayasgalan
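
To make the distinction concrete, a hedged sketch of what exporting the get_image_features path (rather than just the vision encoder) could look like; the wrapper name and call signature are assumptions, not this PR's actual export code:

import torch

# Hypothetical wrapper: export the whole get_image_features path (vision encoder
# plus the connector/merging logic) as one program.
class ImageFeatures(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        return self.model.get_image_features(
            pixel_values=pixel_values, pixel_attention_mask=None
        )

# ep = torch.export.export(ImageFeatures(model).eval(), (example_pixel_values,))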

Member commented:

Ahh, I see; it's probably the part where images are unpadded, and that code is value-dependent.

For my understanding, can we export the vision and the LM parts separately, but not the merging logic? get_image_features might not be 100% exportable for most VLMs, so keeping it outside of export could work better in the long term.

@jackzhxng jackzhxng commented:

The reason we export get_image_features is that if we only exported the vision encoder, we would need to write the rest of the merging logic in C++, which is difficult to do and harder to scale. I wonder if it's possible to upstream an exportable version of get_image_features? If not, I'm happy to just monkey patch like this.
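
In case it helps, a rough sketch of the monkey-patching route, based on the diff snippet above; the body of export_friendly_forward and the exact module path are assumptions rather than the PR's final code:

import types
import torch

def export_friendly_forward(self, pixel_values: torch.FloatTensor,
                            patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
    # Assumes a single, full-resolution, fully attended image, in which case the
    # position ids collapse to arange(num_patches) and no data-dependent loop is needed.
    patch_embeds = self.patch_embedding(pixel_values)
    embeddings = patch_embeds.flatten(2).transpose(1, 2)
    position_ids = torch.arange(embeddings.shape[1], device=embeddings.device)
    return embeddings + self.position_embedding(position_ids)

def patch_vision_embeddings_for_export(model: torch.nn.Module) -> None:
    # Hypothetical path: Idefics3/SmolVLM-style models keep the data-dependent loop
    # in the vision embeddings module.
    embeddings = model.model.vision_model.embeddings
    embeddings.forward = types.MethodType(export_friendly_forward, embeddings)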

Member commented:

> I wonder if it's possible to upstream an exportable version of get_image_features

This would be the best solution, and if you want to submit a PR it would be very welcome. I see that we're passing pixel_attention_mask=None in all cases in this PR, but the vision backbone is still not exportable?

h_indices = torch.arange(nb_patches_h, device=pixel_values.device, dtype=torch.long)
w_indices = torch.arange(nb_patches_w, device=pixel_values.device, dtype=torch.long)

# This replaces bucketize(x, boundaries=[1/N, 2/N, ...], right=True) ≈ floor(x * N), which
@jackzhxng jackzhxng commented:

This is the second change: we don't have a kernel for aten.bucketize, so I compute it manually.
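
For reference, a small sketch of the equivalence this relies on (illustrative values, not the PR code): with uniform boundaries [1/N, 2/N, ..., (N-1)/N] and right=True, bucketize puts x in [k/N, (k+1)/N) into bucket k, which for x in [0, 1) is just floor(x * N) and needs no aten.bucketize kernel:

import torch

# floor(x * N), clamped to cover a potential x == 1.0 edge case.
def bucketize_uniform(x: torch.Tensor, n: int) -> torch.Tensor:
    return torch.clamp((x * n).floor().to(torch.long), max=n - 1)

n = 4
boundaries = torch.arange(1, n) / n                        # [0.25, 0.50, 0.75]
x = torch.tensor([0.0, 0.1, 0.25, 0.49, 0.5, 0.99])
assert torch.equal(bucketize_uniform(x, n),
                   torch.bucketize(x, boundaries, right=True))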

@jackzhxng jackzhxng changed the title [WIP] Add SmolVLM Add SmolVLM Oct 15, 2025