
[quantization] Introduce wrapper for Qwen3VLVisionModel #536

Merged
mhs4670go merged 1 commit into Samsung:main from dvsav:quant_vision_model
Mar 19, 2026

Conversation

@dvsav dvsav commented Mar 5, 2026

This change introduces the QuantQwen3VLVisionModel wrapper to support post-training quantization (PTQ) of the Qwen3VLVisionModel module.

Why?

The Qwen3VLVisionModel module is used in the image encoder part of VLMs.
Trying to quantize Qwen3VLVisionModel via PTQ raises the exception PTQQuantizer: no quantization wrapper for Qwen3VLVisionModel.

What

This change introduces:

  • the class QuantQwen3VLVisionModel (tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_model.py)
  • its registration in _CORE_MODULES (tico/quantization/wrapq/wrappers/registry.py)
  • the unit test class TestQuantQwen3VLVisionModel (test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_model.py)
  • an example of Qwen3VLVisionModel quantization and conversion to Circle (tico/quantization/wrapq/examples/qwen/quantize_vision_model.py)
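For context, the registry mechanism behind the error quoted above can be sketched roughly like this. All names except _CORE_MODULES and the two class names are illustrative, not TICO's actual API:

```python
# Hedged sketch of the wrapper-registry idea: map a module class to its
# quant wrapper and fail loudly when no wrapper is registered.
_CORE_MODULES = {}

def register(module_cls):
    def deco(wrapper_cls):
        _CORE_MODULES[module_cls] = wrapper_cls
        return wrapper_cls
    return deco

class Qwen3VLVisionModel: ...  # stand-in for the transformers class

@register(Qwen3VLVisionModel)
class QuantQwen3VLVisionModel:
    def __init__(self, wrapped):
        self.wrapped = wrapped  # the FP module being quantized

def wrap(module):
    try:
        return _CORE_MODULES[type(module)](module)
    except KeyError:
        raise RuntimeError(
            f"PTQQuantizer: no quantization wrapper for {type(module).__name__}"
        )

print(type(wrap(Qwen3VLVisionModel())).__name__)  # QuantQwen3VLVisionModel
```

Registering the wrapper in _CORE_MODULES is what turns the former exception into a working quantization path.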

Design

  • This code follows the practice demonstrated in tico/quantization/wrapq/wrappers/llama/quant_decoder_layer_prefill.py, namely precomputing position embeddings ahead of time to avoid recomputing them during inference.
  • See also the discussion on precomputing position embeddings: [quantization] Introduce wrapper for Qwen3VLVisionRotaryEmbedding #496
  • Position embeddings are precomputed and stored in pos_embed_template.
  • RoPE position embeddings are precomputed and stored in rope_cos_template and rope_sin_template.
  • The QuantQwen3VLVisionModel implementation deliberately uses numerous static methods (independent of the self object). This brings the benefits of functional programming: dependencies and data flow are explicit, and the code is easier to unit-test.
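The RoPE precomputation described above can be sketched as follows. This is a simplified illustration under assumed defaults (rotary theta 10000, row-major token order, merge-order permutation omitted); precompute_rope_templates is an illustrative name, not the wrapper's API:

```python
import torch

def precompute_rope_templates(head_dim: int, grid_h: int, grid_w: int):
    # Rotary table dimension, mirroring Qwen3VLVisionRotaryEmbedding(head_dim // 2)
    dim = head_dim // 2
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
    max_hw = max(grid_h, grid_w)
    freq_table = torch.outer(torch.arange(max_hw).float(), inv_freq)  # (max_hw, dim // 2)

    # Row/col position ids for every patch in a single (h, w) grid
    rows = torch.arange(grid_h).repeat_interleave(grid_w)
    cols = torch.arange(grid_w).repeat(grid_h)
    freqs = torch.cat((freq_table[rows], freq_table[cols]), dim=-1)  # (h * w, dim)

    # Duplicate, as the reference forward() does: emb = cat((rope, rope), -1)
    emb = torch.cat((freqs, freqs), dim=-1)  # (h * w, head_dim)
    return emb.cos(), emb.sin()

cos_t, sin_t = precompute_rope_templates(head_dim=64, grid_h=24, grid_w=24)
print(tuple(cos_t.shape))  # (576, 64)
```

Because the grid is fixed via vision_grid_thw, these tensors can be stored once (e.g. as rope_cos_template / rope_sin_template) instead of being rebuilt on every forward pass.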

Unit Tests

Unit test results with coverage information:

$ coverage run -m pytest test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_model.py -v
=================================================================== test session starts ====================================================================
platform linux -- Python 3.12.3, pytest-9.0.2, pluggy-1.6.0 -- /home/sdmitry/myenv/bin/python
cachedir: .pytest_cache
rootdir: /home/sdmitry/TICO
configfile: pyproject.toml
plugins: anyio-4.11.0
collected 17 items

test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_model.py::TestQuantQwen3VLVisionModel::test_create_freq_table                        PASSED [  5%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_model.py::TestQuantQwen3VLVisionModel::test_different_grid_sizes                     PASSED [ 10%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_model.py::TestQuantQwen3VLVisionModel::test_fast_pos_embed_interpolate               PASSED [ 15%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_model.py::TestQuantQwen3VLVisionModel::test_forward_grid_mismatch_during_calibration PASSED [ 21%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_model.py::TestQuantQwen3VLVisionModel::test_get_vision_grid_thw_from_config          PASSED [ 31%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_model.py::TestQuantQwen3VLVisionModel::test_get_vision_grid_thw_missing_config       PASSED [ 36%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_model.py::TestQuantQwen3VLVisionModel::test_init_missing_vision_grid_thw             PASSED [ 42%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_model.py::TestQuantQwen3VLVisionModel::test_init_with_valid_config                   PASSED [ 47%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_model.py::TestQuantQwen3VLVisionModel::test_mode_transitions                         PASSED [ 52%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_model.py::TestQuantQwen3VLVisionModel::test_observer_count                           PASSED [ 57%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_model.py::TestQuantQwen3VLVisionModel::test_output_structure                         PASSED [ 63%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_model.py::TestQuantQwen3VLVisionModel::test_precompute_cu_seqlens                    PASSED [ 68%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_model.py::TestQuantQwen3VLVisionModel::test_precompute_rope_inv_freq                 PASSED [ 78%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_model.py::TestQuantQwen3VLVisionModel::test_precompute_rope_position_embeddings      PASSED [ 84%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_model.py::TestQuantQwen3VLVisionModel::test_precomputed_embeddings_shape             PASSED [ 89%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_model.py::TestQuantQwen3VLVisionModel::test_registration_in_registry                 PASSED [ 94%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_model.py::TestQuantQwen3VLVisionModel::test_rot_pos_emb                              PASSED [100%]

========================================================== 17 passed, 3 subtests passed in 8.94s ===========================================================

Coverage info (irrelevant files skipped):

$ coverage report -m
Name                                                                   Stmts   Miss  Cover   Missing
----------------------------------------------------------------------------------------------------
...
tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_attn.py             105      0   100%
tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_block.py             42      0   100%
tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_mlp.py               33      0   100%
tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_model.py            164      3    98%   249, 402-405
tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_patch_embed.py       25      0   100%
tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_patch_merger.py      36      1    97%   101
tico/quantization/wrapq/wrappers/registry.py                               36      1    97%   256
...
----------------------------------------------------------------------------------------------------
TOTAL                                                                   11197   6910    38%

Example Script

PEIR depends on the number of Qwen3VLVisionBlock layers (Qwen3VLVisionConfig.depth), so several script runs with different depths are provided below.
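As an aside for readers unfamiliar with the metric: the exact PEIR formula is not shown in this PR. A plausible reading is peak absolute error relative to the reference output's value interval; the sketch below assumes that definition, and peir is an illustrative helper, not the script's actual code:

```python
import torch

def peir(reference: torch.Tensor, quantized: torch.Tensor) -> float:
    """Assumed definition: peak absolute error as a percentage of the
    reference output's value interval (max - min). Illustrative only."""
    peak_error = (reference - quantized).abs().max()
    interval = reference.max() - reference.min()
    return (peak_error / interval * 100.0).item()

ref = torch.tensor([0.0, 1.0, 2.0, 4.0])
quant = torch.tensor([0.1, 0.9, 2.2, 3.8])
print(f"{peir(ref, quant):.1f} %")  # ≈ 5.0 %
```

Under this reading, a growing PEIR with depth matches the intuition that quantization error accumulates across stacked vision blocks.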

$ python tico/quantization/wrapq/examples/qwen/quantize_vision_model.py

depth: 1
Input shape: (1, 3, 2, 384, 384)
grid_thw: [[1, 24, 24]]
┌───────────── Quantization Error Summary ─────────────
│ Mean |diff|: 0.340295
│ PEIR       : 23.146836 %
└──────────────────────────────────────────────────────
    ┌────────────────────────────────────────────┐
 4.5┤                                            │
    │                                            │
    │                                     • •    │
 2.9┤                                 •••••••••  │
    │                              ••••••••••••  │
    │                           ••••••••••••••   │
    │                      •••••••••••••••••     │
 1.3┤                    ••••••••••••••••••      │
    │                  •••••••••••••••••••       │
    │               •••••••••••••••••••          │
-0.2┤              ••••••••••••••••••            │
    │            •••••••••••••••••••             │
    │         •••••••••••••••••••                │
    │        •••••••••••••••••                   │
-1.8┤      ••••••••••••••••                      │
    │    ••••••••••••••••                        │
    │    •••••••••••••                           │
-3.4┤  ••••••••••••                              │
    │     •  •                                   │
    │                                            │
    │                                            │
-5.0┤                                            │
    └┬──────────┬──────────┬─────────┬──────────┬┘
   -5.0       -2.6       -0.2       2.1       4.5

Circle model saved as 'qwen3vl_vision_model.q.circle'

depth: 2
Input shape: (1, 3, 2, 384, 384)
grid_thw: [[1, 24, 24]]
┌───────────── Quantization Error Summary ─────────────
│ Mean |diff|: 0.501045
│ PEIR       : 30.942653 %
└──────────────────────────────────────────────────────
    ┌────────────────────────────────────────────┐
 5.8┤                                            │
    │                                            │
    │                                            │
 3.9┤                                            │
    │                                    •  •    │
    │                            ••••••••••••    │
    │                        •••••••••••••••• •  │
 2.0┤                     •••••••••••••••••••    │
    │                 ••••••••••••••••••••       │
    │                ••••••••••••••••••••        │
 0.1┤            •••••••••••••••••••••••         │
    │        •••••••••••••••••••••••••           │
    │      • ••••••••••••••••••••••              │
    │     •••••••••••••••••••••  •               │
-1.8┤     ••••••••••••••••••                     │
    │   •••••••••••••••••••                      │
    │  ••••••••••••••••                          │
-3.7┤   •••••••••                                │
    │                                            │
    │                                            │
    │                                            │
-5.6┤                                            │
    └┬──────────┬──────────┬─────────┬──────────┬┘
   -5.6       -2.7        0.1       3.0       5.8

Circle model saved as 'qwen3vl_vision_model.q.circle'

depth: 3
Input shape: (1, 3, 2, 384, 384)
grid_thw: [[1, 24, 24]]
┌───────────── Quantization Error Summary ─────────────
│ Mean |diff|: 0.625923
│ PEIR       : 37.749843 %
└──────────────────────────────────────────────────────
    ┌────────────────────────────────────────────┐
 5.7┤                                            │
    │                                            │
    │                                            │
 3.8┤                                            │
    │                               ••  •• ••••  │
    │                        • •••••••••••••••   │
    │                     •••••••••••••••••••••  │
 1.8┤                   •••••••••••••••••••••••  │
    │             • •••••••••••••••••••••••• •   │
    │          • •••••••••••••••••••••••••••     │
-0.1┤         •••••••••••••••••••••••••••        │
    │     •  ••••••••••••••••••••••••••          │
    │      ••••••••••••••••••••••••• ••          │
    │   •••••••••••••••••••••••••                │
-2.0┤    ••••••••••••••••••••••                  │
    │   •••••••••••••••••••                      │
    │  • •••••••••••••• •                        │
-4.0┤   • • ••• ••                               │
    │                                            │
    │                                            │
    │                                            │
-5.9┤                                            │
    └┬──────────┬──────────┬─────────┬──────────┬┘
   -5.9       -3.0       -0.1       2.8       5.7

Circle model saved as 'qwen3vl_vision_model.q.circle'

depth: 24
Input shape: (1, 3, 2, 384, 384)
grid_thw: [[1, 24, 24]]
┌───────────── Quantization Error Summary ─────────────
│ Mean |diff|: 2.005036
│ PEIR       : 51.661849 %
└──────────────────────────────────────────────────────
     ┌───────────────────────────────────────────┐
 14.2┤                                           │
     │                                           │
     │                                           │
  9.3┤                                           │
     │                                           │
     │                                           │
     │                                           │
  4.3┤                       •     •             │
     │            • •••••••••••••••••••••••••    │
     │        •••••••••••••••••••••••••••••••••  │
 -0.7┤  • •  ••••••••••••••••••••••••••••••••••  │
     │    ••••••••••••••••••••••••••••••••••     │
     │      •••••••••••••••••••••••• ••          │
     │                                           │
 -5.7┤                                           │
     │                                           │
     │                                           │
-10.6┤                                           │
     │                                           │
     │                                           │
     │                                           │
-15.6┤                                           │
     └┬──────────┬─────────┬──────────┬─────────┬┘
    -15.6      -8.1      -0.7        6.8     14.2

Circle model saved as 'qwen3vl_vision_model.q.circle'

@dvsav dvsav force-pushed the quant_vision_model branch 17 times, most recently from 18b9abd to 4a8bb88 Compare March 13, 2026 08:16

dvsav commented Mar 13, 2026

Reference Code of Qwen3VLVisionModel

Below is the source code of transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLVisionModel (transformers version 5.3.0 - latest at the time of writing):

class Qwen3VLVisionModel(Qwen3VLPreTrainedModel):
    config: Qwen3VLVisionConfig
    _no_split_modules = ["Qwen3VLVisionBlock"]
    _can_record_outputs = {
        "hidden_states": Qwen3VLVisionBlock,
        "attentions": Qwen3VLVisionAttention,
    }

    def __init__(self, config, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)
        self.spatial_merge_size = config.spatial_merge_size
        self.patch_size = config.patch_size
        self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size

        self.patch_embed = Qwen3VLVisionPatchEmbed(
            config=config,
        )

        self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size)
        self.num_grid_per_side = int(config.num_position_embeddings**0.5)

        head_dim = config.hidden_size // config.num_heads
        self.rotary_pos_emb = Qwen3VLVisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList([Qwen3VLVisionBlock(config) for _ in range(config.depth)])
        self.merger = Qwen3VLVisionPatchMerger(
            config=config,
            use_postshuffle_norm=False,
        )

        self.deepstack_visual_indexes = config.deepstack_visual_indexes
        self.deepstack_merger_list = nn.ModuleList(
            [
                Qwen3VLVisionPatchMerger(
                    config=config,
                    use_postshuffle_norm=True,
                )
                for _ in range(len(config.deepstack_visual_indexes))
            ]
        )

        self.gradient_checkpointing = False

        self.post_init()

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        merge_size = self.spatial_merge_size

        max_hw = int(grid_thw[:, 1:].max().item())
        freq_table = self.rotary_pos_emb(max_hw)  # (max_hw, dim // 2)
        device = freq_table.device

        total_tokens = int(torch.prod(grid_thw, dim=1).sum().item())
        pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device)

        offset = 0
        for num_frames, height, width in grid_thw:
            merged_h, merged_w = height // merge_size, width // merge_size

            block_rows = torch.arange(merged_h, device=device)  # block row indices
            block_cols = torch.arange(merged_w, device=device)  # block col indices
            intra_row = torch.arange(merge_size, device=device)  # intra-block row offsets
            intra_col = torch.arange(merge_size, device=device)  # intra-block col offsets

            # Compute full-resolution positions
            row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
            col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]

            row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)
            col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1)

            coords = torch.stack((row_idx, col_idx), dim=-1)

            if num_frames > 1:
                coords = coords.repeat(num_frames, 1)

            num_tokens = coords.shape[0]
            pos_ids[offset : offset + num_tokens] = coords
            offset += num_tokens

        embeddings = freq_table[pos_ids]  # lookup rotary embeddings
        embeddings = embeddings.flatten(1)
        return embeddings

    def fast_pos_embed_interpolate(self, grid_thw):
        grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
        device = self.pos_embed.weight.device

        idx_list = [[] for _ in range(4)]
        weight_list = [[] for _ in range(4)]

        for t, h, w in zip(grid_ts, grid_hs, grid_ws):
            h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
            w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)

            h_idxs_floor = h_idxs.int()
            w_idxs_floor = w_idxs.int()
            h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
            w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)

            dh = h_idxs - h_idxs_floor
            dw = w_idxs - w_idxs_floor

            base_h = h_idxs_floor * self.num_grid_per_side
            base_h_ceil = h_idxs_ceil * self.num_grid_per_side

            indices = [
                (base_h[None].T + w_idxs_floor[None]).flatten(),
                (base_h[None].T + w_idxs_ceil[None]).flatten(),
                (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
                (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
            ]

            weights = [
                ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
                ((1 - dh)[None].T * dw[None]).flatten(),
                (dh[None].T * (1 - dw)[None]).flatten(),
                (dh[None].T * dw[None]).flatten(),
            ]

            for i in range(4):
                idx_list[i].extend(indices[i].tolist())
                weight_list[i].extend(weights[i].tolist())

        idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device)
        weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device)
        pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None]
        patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]

        patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])

        patch_pos_embeds_permute = []
        merge_size = self.config.spatial_merge_size
        for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
            pos_embed = pos_embed.repeat(t, 1)
            pos_embed = (
                pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
                .permute(0, 1, 3, 2, 4, 5)
                .flatten(0, 4)
            )
            patch_pos_embeds_permute.append(pos_embed)
        patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
        return patch_pos_embeds

    @check_model_inputs
    def forward(
        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs]
    ) -> tuple | BaseModelOutputWithDeepstackFeatures:
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
                The final hidden states of the model.
            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
                The temporal, height and width of feature shape of each image in LLM.

        Returns:
            `torch.Tensor`: hidden_states.
        """
        hidden_states = self.patch_embed(hidden_states)

        pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
        hidden_states = hidden_states + pos_embeds

        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        seq_len, _ = hidden_states.size()
        hidden_states = hidden_states.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            # Select dtype based on the following factors:
            #  - FA2 requires that cu_seqlens_q must have dtype int32
            #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
            # See https://github.com/huggingface/transformers/pull/34852 for more information
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        deepstack_feature_lists = []
        for layer_num, blk in enumerate(self.blocks):
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens,
                position_embeddings=position_embeddings,
                **kwargs,
            )
            if layer_num in self.deepstack_visual_indexes:
                deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)](
                    hidden_states
                )
                deepstack_feature_lists.append(deepstack_feature)

        merged_hidden_states = self.merger(hidden_states)

        return BaseModelOutputWithDeepstackFeatures(
            last_hidden_state=hidden_states,
            pooler_output=merged_hidden_states,
            deepstack_features=deepstack_feature_lists,
        )

Note that the structure of the Qwen3VLVisionModel.forward return value depends on the transformers version, which this PR has to take into account:

    # transformers==5.3.0
    def forward(...):
        ...
        merged_hidden_states = self.merger(hidden_states)
        return BaseModelOutputWithDeepstackFeatures(
            last_hidden_state=hidden_states,
            pooler_output=merged_hidden_states,
            deepstack_features=deepstack_feature_lists,
        )

    # transformers==4.57.0
    def forward(...):
        ...
        hidden_states = self.merger(hidden_states)
        return hidden_states, deepstack_feature_lists
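One way for a wrapper to tolerate both shapes is to normalize the output once at the call boundary. A minimal sketch of that idea (normalize_vision_output is an illustrative name, not the PR's actual code):

```python
def normalize_vision_output(output):
    """Map either return structure to (merged_hidden_states, deepstack_features)."""
    if isinstance(output, tuple):
        # transformers==4.57.0 style: (hidden_states, deepstack_feature_lists)
        merged, deepstack = output
    else:
        # transformers==5.3.0 style: BaseModelOutputWithDeepstackFeatures,
        # where pooler_output holds the merged hidden states
        merged, deepstack = output.pooler_output, output.deepstack_features
    return merged, deepstack
```

This keeps the version-dependent branching in one place instead of spreading it through the forward path.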

@dvsav dvsav marked this pull request as ready for review March 13, 2026 08:34
@dvsav dvsav force-pushed the quant_vision_model branch from 4a8bb88 to 97d5ec3 Compare March 13, 2026 09:04
@mhs4670go
Contributor

When I ran the example script with transformers 4.57.3, I got the errors below.

File "/home/seongwoo/TICO/.venv/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2318, in _load_attr
    result = BuiltinVariable(getattr).call_function(
  File "/home/seongwoo/TICO/.venv/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 1111, in call_function
    return handler(tx, args, kwargs)
  File "/home/seongwoo/TICO/.venv/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 945, in builtin_dispatch
    rv = fn(tx, args, kwargs)
  File "/home/seongwoo/TICO/.venv/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 850, in call_self_handler
    result = self_handler(tx, *args, **kwargs)
  File "/home/seongwoo/TICO/.venv/lib/python3.10/site-packages/torch/_dynamo/variables/builtin.py", line 1897, in call_getattr
    return obj.var_getattr(tx, name)
  File "/home/seongwoo/TICO/.venv/lib/python3.10/site-packages/torch/_dynamo/variables/misc.py", line 1075, in var_getattr
    attr_value = getattr(self.value, name)
torch._dynamo.exc.InternalTorchDynamoError: AttributeError: module 'transformers.models.qwen3_vl.modeling_qwen3_vl' has no attribute 'BaseModelOutputWithDeepstackFeatures'

from user code:
   File "/home/seongwoo/TICO/tico/quantization/wrapq/wrappers/ptq_wrapper.py", line 68, in forward
    return self.wrapped(*args, **kwargs)
  File "/home/seongwoo/TICO/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/seongwoo/TICO/tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_model.py", line 407, in forward
    from transformers.models.qwen3_vl.modeling_qwen3_vl import (

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

def _get_vision_grid_thw(qcfg: Optional[PTQConfig]) -> torch.Tensor:
    """Extract vision_grid_thw from config for precomputing RoPE embeddings"""
    if qcfg and hasattr(qcfg, "vision_grid_thw"):
        grid_thw = torch.tensor([getattr(qcfg, "vision_grid_thw")])
Contributor Author

@dvsav dvsav Mar 17, 2026


I'm not sure a PTQConfig attribute is the right place to store vision_grid_thw, as the latter has nothing to do with quantization, but I couldn't come up with a better idea.

Contributor

@mhs4670go mhs4670go Mar 18, 2026


That's right! I was gonna work to resolve this after merging this PR.

right place to store vision_grid_thw as the latter has nothing to do with quantization,

I think it's okay to put this in PTQConfig, because the information is needed to wrap the module.

Contributor


Please see #560.

Contributor Author


Please see #560.

I guess that's a subject for another PR.
(Just in case) Please let me know if you'd like me to implement the #560 approach in this PR.

@dvsav dvsav force-pushed the quant_vision_model branch 4 times, most recently from 059685e to 7eb7aae Compare March 18, 2026 09:14

dvsav commented Mar 18, 2026

When I ran the example script with transformers 4.57.3, I got below errors.

Hi @mhs4670go
Thanks for catching that 👍 I've fixed the error.

@mhs4670go
Contributor

Thank you for the great work. But I think the old/new-style version branching looks fragile here.

We already use runtime capability detection in torch_compat.py, so it would be better to follow the same pattern for transformers as well.

In this case, the wrapper does not really care about an “old” or “new” version. What it actually needs to know is whether BaseModelOutputWithDeepstackFeatures is available at runtime. So instead of storing a vague state like "old" / "new", I would suggest moving this to a feature probe such as:

@functools.lru_cache(maxsize=None)
def qwen3_vl_has_deepstack_model_output() -> bool:
    try:
        module = importlib.import_module(
            "transformers.models.qwen3_vl.modeling_qwen3_vl"
        )
    except ImportError:
        return False
    return hasattr(module, "BaseModelOutputWithDeepstackFeatures")

Then in the wrapper:

self.has_deepstack_model_output = qwen3_vl_has_deepstack_model_output()

and later:

if self.has_deepstack_model_output:
    ...
else:
    ...

This keeps the code aligned with our existing compatibility policy:

  • branch on features, not package versions
  • make the condition name reflect the actual capability being checked

I would also recommend placing this in a new file such as:

tico/quantization/utils/transformers_compat.py

so future transformers-specific probes can live in one place.

Here's the full version that I implemented. Please include it in the PR.

"""
Runtime capability-detection helpers for Hugging Face `transformers`.

Instead of branching on specific package versions such as
`transformers >= 5.x`, use these helpers to detect whether the exact
symbol or behavior required by the code is available at runtime.

Each probe is cached once per process with `functools.lru_cache`,
so repeated checks have negligible overhead.
"""

import functools
import importlib


@functools.lru_cache(maxsize=None)
def qwen3_vl_has_deepstack_model_output() -> bool:
    """
    Return whether Qwen3-VL exposes
    `BaseModelOutputWithDeepstackFeatures` in its modeling module.

    This wrapper only needs to know whether the structured return type is
    available. Using feature detection keeps the code resilient to
    backports, forward ports, and non-linear package versioning.

    Returns
    -------
    bool
        ``True`` if
        `transformers.models.qwen3_vl.modeling_qwen3_vl`
        defines `BaseModelOutputWithDeepstackFeatures`,
        otherwise ``False``.
    """
    try:
        module = importlib.import_module(
            "transformers.models.qwen3_vl.modeling_qwen3_vl"
        )
    except ImportError:
        return False

    return hasattr(module, "BaseModelOutputWithDeepstackFeatures")

@dvsav dvsav force-pushed the quant_vision_model branch 3 times, most recently from 793bc21 to 08ab2ea Compare March 18, 2026 17:29
This change introduces QuantQwen3VLVisionModel wrapper to support post-training quantization of Qwen3VLVisionModel operation.

TICO-DCO-1.0-Signed-off-by: d.savchenkov <d.savchenkov@partner.samsung.com>
@dvsav dvsav force-pushed the quant_vision_model branch from 08ab2ea to c832fad Compare March 18, 2026 17:34

dvsav commented Mar 18, 2026

We already use runtime capability detection in torch_compat.py, so it would be better to follow the same pattern for transformers as well.

👍 Done

Contributor

@mhs4670go mhs4670go left a comment


LGTM. Thanks!

@mhs4670go mhs4670go merged commit e8f4ed3 into Samsung:main Mar 19, 2026
7 checks passed
@dvsav dvsav deleted the quant_vision_model branch March 19, 2026 06:31