Skip to content

Commit 7eb7aae

Browse files
author
d-savchenkov
committed
[quantization] Introduce wrapper for Qwen3VLVisionModel
This change introduces QuantQwen3VLVisionModel wrapper to support post-training quantization of Qwen3VLVisionModel operation. TICO-DCO-1.0-Signed-off-by: d.savchenkov <d.savchenkov@partner.samsung.com>
1 parent fc67e35 commit 7eb7aae

File tree

4 files changed

+967
-0
lines changed

4 files changed

+967
-0
lines changed
Lines changed: 375 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,375 @@
1+
# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import math
16+
import unittest
17+
from typing import Tuple
18+
19+
import torch
20+
21+
from tico.quantization.config.ptq import PTQConfig
22+
from tico.quantization.wrapq.mode import Mode
23+
from tico.quantization.wrapq.utils.version import has_transformers_for
24+
from tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_model import (
25+
QuantQwen3VLVisionModel,
26+
)
27+
28+
29+
# Skip reason used by the class-level skipUnless guard below when the
# installed `transformers` package does not provide Qwen3-VL support.
skip_msg = "transformers not installed — skipping Qwen3VLVisionModel tests"
30+
31+
32+
@unittest.skipUnless(has_transformers_for("qwen3-vl"), skip_msg)
class TestQuantQwen3VLVisionModel(unittest.TestCase):
    """Unit tests for the QuantQwen3VLVisionModel PTQ wrapper.

    Exercises the wrapper's static precomputation helpers (RoPE inverse
    frequencies, cumulative sequence lengths, position-embedding
    interpolation), its construction/validation against PTQConfig, the
    NO_QUANT -> CALIB -> QUANT mode lifecycle, and end-to-end forward
    output shapes for several vision grid sizes.
    """

    # Class-level fixtures populated once in setUpClass.
    fp_model: torch.nn.Module  # floating-point Qwen3VLVisionModel under test
    hidden_size: int  # cfg.hidden_size of the test config
    num_heads: int  # cfg.num_heads of the test config
    head_dim: int  # hidden_size // num_heads
    theta: float  # RoPE base frequency (falls back to 10000.0)
    transformers_version: str

    @classmethod
    def setUpClass(cls):
        from transformers.models.qwen3_vl.configuration_qwen3_vl import (
            Qwen3VLVisionConfig,
        )
        from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel

        # Use smaller sizes for testing
        cfg = Qwen3VLVisionConfig(
            hidden_size=64,
            num_heads=4,
            depth=2,  # Smaller depth for faster testing
            temporal_patch_size=2,
            patch_size=16,
        )

        # Ensure eager attention implementation so outputs are deterministic
        # and do not require GPU flash attention kernels.
        # Some versions use `_attn_implementation`, others expose `attn_implementation`.
        if not hasattr(cfg, "_attn_implementation"):
            setattr(cfg, "_attn_implementation", "eager")
        else:
            cfg._attn_implementation = "eager"

        cls.fp_model = Qwen3VLVisionModel(cfg)
        cls.hidden_size = cfg.hidden_size
        cls.num_heads = cfg.num_heads
        cls.head_dim = cls.hidden_size // cls.num_heads
        # Some transformers versions expose the RoPE base as an attribute;
        # fall back to the conventional 10000.0 default otherwise.
        cls.theta = (
            cls.fp_model.rotary_pos_emb.theta
            if hasattr(cls.fp_model.rotary_pos_emb, "theta")
            else 10000.0
        )

    def _create_test_inputs(
        self, grid_thw: Tuple[int, int, int] = (1, 8, 8)
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Helper to create test inputs for VisionModel."""
        t, h, w = grid_thw
        num_patches = t * h * w
        # Input shape: (seq_len, in_channels * temporal_patch_size * patch_size * patch_size)
        # NOTE: 3*2*16*16 is hard-coded to match temporal_patch_size=2 and
        # patch_size=16 chosen in setUpClass — keep them in sync.
        hidden_states = torch.randn(
            num_patches, 3 * 2 * 16 * 16
        )  # 3 channels, 2 temporal, 16x16 patches
        grid_tensor = torch.tensor([grid_thw])
        return hidden_states, grid_tensor

    def test_get_vision_grid_thw_from_config(self):
        """Test _get_vision_grid_thw static method with valid config."""
        # Test with valid config
        ptq_config = PTQConfig()
        setattr(ptq_config, "vision_grid_thw", [1, 8, 8])

        grid_thw = QuantQwen3VLVisionModel._get_vision_grid_thw(ptq_config)
        expected = torch.tensor([[1, 8, 8]])
        self.assertTrue(torch.equal(grid_thw, expected))
        self.assertEqual(grid_thw.shape, (1, 3))

    def test_get_vision_grid_thw_missing_config(self):
        """Test _get_vision_grid_thw raises error when config is missing."""
        # Test with None config
        with self.assertRaises(ValueError) as context:
            QuantQwen3VLVisionModel._get_vision_grid_thw(None)
        self.assertIn("vision_grid_thw must be specified", str(context.exception))

        # Test with config without vision_grid_thw
        ptq_config = PTQConfig()
        with self.assertRaises(ValueError) as context:
            QuantQwen3VLVisionModel._get_vision_grid_thw(ptq_config)
        self.assertIn("vision_grid_thw must be specified", str(context.exception))

    def test_precompute_rope_inv_freq(self):
        """Test _precompute_rope_inv_freq static method."""
        dim = 32
        theta = 10000.0
        inv_freq = QuantQwen3VLVisionModel._precompute_rope_inv_freq(dim, theta)

        self.assertEqual(inv_freq.shape, (dim // 2,))
        self.assertTrue(torch.all(inv_freq > 0))
        # Check that frequencies are decreasing
        self.assertTrue(torch.all(inv_freq[:-1] >= inv_freq[1:]))

    def test_precompute_cu_seqlens(self):
        """Test _precompute_cu_seqlens static method."""
        grid_thw = torch.tensor(
            [[1, 8, 8], [2, 4, 4]]
        )  # 1*8*8 + 2*4*4 = 96 total patches
        cu_seqlens = QuantQwen3VLVisionModel._precompute_cu_seqlens(grid_thw)

        # 2 images expand to 3 temporal sequences (1 + 2), plus the leading 0.
        self.assertEqual(cu_seqlens.shape, (4,))  # 3 images + 1 padding
        self.assertEqual(cu_seqlens[0].item(), 0)
        self.assertEqual(cu_seqlens[1].item(), 64)  # 1st image: 1*8*8 = 64 patches
        self.assertEqual(cu_seqlens[2].item(), 80)  # 2nd image: 1*4*4 = 16 patches
        self.assertEqual(
            cu_seqlens[3].item(), 96
        )  # 3rd image: 1*4*4 = 16 patches, total 96

    def test_precompute_rope_position_embeddings(self):
        """Test _precompute_rope_position_embeddings static method."""
        grid_thw = torch.tensor([[1, 8, 8]])
        inv_freq = QuantQwen3VLVisionModel._precompute_rope_inv_freq(
            dim=self.head_dim // 2,
            theta=self.theta,
        )

        cos_t, sin_t = QuantQwen3VLVisionModel._precompute_rope_position_embeddings(
            merge_size=2,
            rope_inv_freq=inv_freq,
            grid_thw=grid_thw,
        )

        expected_patches = math.prod(grid_thw[0].tolist())  # t * h * w = 1 * 8 * 8 = 64
        self.assertEqual(cos_t.shape, (expected_patches, self.head_dim))
        self.assertEqual(sin_t.shape, (expected_patches, self.head_dim))

    def test_rot_pos_emb(self):
        """Test _rot_pos_emb static method."""
        grid_thw = torch.tensor([[1, 8, 8]])
        inv_freq = QuantQwen3VLVisionModel._precompute_rope_inv_freq(
            dim=self.head_dim // 2,
            theta=self.theta,
        )

        rotary_pos_emb = QuantQwen3VLVisionModel._rot_pos_emb(2, inv_freq, grid_thw)

        expected_patches = math.prod(grid_thw[0].tolist())  # t * h * w = 1 * 8 * 8 = 64
        self.assertEqual(rotary_pos_emb.shape, (expected_patches, self.head_dim // 2))

    def test_create_freq_table(self):
        """Test _create_freq_table static method."""
        seqlen = 64
        inv_freq = torch.randn(16)  # dim//2 = 32//2 = 16
        freq_table = QuantQwen3VLVisionModel._create_freq_table(seqlen, inv_freq)

        self.assertEqual(freq_table.shape, (seqlen, inv_freq.shape[0]))

    def test_fast_pos_embed_interpolate(self):
        """Test _fast_pos_embed_interpolate static method."""
        grid_thw = torch.tensor([[1, 8, 8]])
        pos_embeds = QuantQwen3VLVisionModel._fast_pos_embed_interpolate(
            merge_size=2,
            num_grid_per_side=48,  # From model config
            pos_embedder=self.fp_model.pos_embed,
            grid_thw=grid_thw,
        )

        expected_patches = math.prod(grid_thw[0].tolist())  # t * h * w = 1 * 8 * 8 = 64
        self.assertEqual(pos_embeds.shape, (expected_patches, self.hidden_size))

    def test_init_with_valid_config(self):
        """Test successful initialization with valid config."""
        ptq_config = PTQConfig()
        setattr(ptq_config, "vision_grid_thw", [1, 8, 8])

        q_model = QuantQwen3VLVisionModel(
            self.fp_model, qcfg=ptq_config, fp_name="test_model"
        )

        # Check that buffers are registered
        self.assertTrue(hasattr(q_model, "cu_seqlens_template"))
        self.assertTrue(hasattr(q_model, "pos_embed_template"))
        self.assertTrue(hasattr(q_model, "rope_inv_freq"))
        self.assertTrue(hasattr(q_model, "rope_cos_template"))
        self.assertTrue(hasattr(q_model, "rope_sin_template"))

        # Check submodule wrapping
        self.assertIsNotNone(q_model.patch_embed)
        self.assertEqual(len(q_model.blocks), len(self.fp_model.blocks))
        self.assertIsNotNone(q_model.merger)
        self.assertEqual(
            len(q_model.deepstack_merger_list), len(self.fp_model.deepstack_merger_list)
        )

    def test_init_missing_vision_grid_thw(self):
        """Test initialization fails without vision_grid_thw."""
        ptq_config = PTQConfig()

        with self.assertRaises(ValueError) as context:
            QuantQwen3VLVisionModel(
                self.fp_model, qcfg=ptq_config, fp_name="test_model"
            )
        self.assertIn("vision_grid_thw must be specified", str(context.exception))

    def test_mode_transitions(self):
        """Test quantization mode transitions: NO_QUANT → CALIB → QUANT"""
        ptq_config = PTQConfig()
        setattr(ptq_config, "vision_grid_thw", [1, 8, 8])
        q_model = QuantQwen3VLVisionModel(
            self.fp_model, qcfg=ptq_config, fp_name="test_model"
        )
        self.assertIs(q_model._mode, Mode.NO_QUANT)

        q_model.enable_calibration()
        self.assertIs(q_model._mode, Mode.CALIB)

        # Run forward pass during calibration
        hidden_states, grid_thw = self._create_test_inputs((1, 8, 8))
        _ = q_model(hidden_states, grid_thw)

        q_model.freeze_qparams()
        self.assertIs(q_model._mode, Mode.QUANT)

    def test_forward_grid_mismatch_during_calibration(self):
        """Test forward pass fails with mismatched grid_thw during calibration."""
        ptq_config = PTQConfig()
        setattr(ptq_config, "vision_grid_thw", [1, 8, 8])
        q_model = QuantQwen3VLVisionModel(
            self.fp_model, qcfg=ptq_config, fp_name="test_model"
        )
        q_model.enable_calibration()

        # Try with different grid
        hidden_states, grid_thw = self._create_test_inputs((1, 4, 4))

        with self.assertRaises(AssertionError) as context:
            _ = q_model(hidden_states, grid_thw)
        self.assertIn("grid_thw", str(context.exception))

    def test_observer_count(self):
        """Test that the wrapper has the correct number of observers."""
        ptq_config = PTQConfig()
        setattr(ptq_config, "vision_grid_thw", [1, 8, 8])
        q_model = QuantQwen3VLVisionModel(
            self.fp_model, qcfg=ptq_config, fp_name="test_model"
        )

        observers = list(q_model._all_observers())
        # Should have 4 local observers: pos_embeds, pos_add, rope_cos, rope_sin
        self.assertEqual(len(observers), 4)

    def test_precomputed_embeddings_shape(self):
        """Test that precomputed embeddings have correct shapes."""
        ptq_config = PTQConfig()
        setattr(ptq_config, "vision_grid_thw", [1, 8, 8])
        q_model = QuantQwen3VLVisionModel(
            self.fp_model, qcfg=ptq_config, fp_name="test_model"
        )

        expected_patches = math.prod(
            getattr(ptq_config, "vision_grid_thw")
        )  # t * h * w = 1 * 8 * 8 = 64

        # Check position embeddings
        self.assertEqual(
            q_model.pos_embed_template.shape, (expected_patches, self.hidden_size)
        )

        # Check RoPE embeddings
        self.assertEqual(
            q_model.rope_cos_template.shape,
            (expected_patches, self.head_dim),
        )
        self.assertEqual(
            q_model.rope_sin_template.shape,
            (expected_patches, self.head_dim),
        )

        # Check cumulative sequence lengths
        self.assertEqual(q_model.cu_seqlens_template.shape, (2,))  # 1 image + 1 padding

    def test_registration_in_registry(self):
        """Test that Qwen3VLVisionModel is properly registered."""
        from tico.quantization.wrapq.wrappers.registry import lookup
        from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel

        wrapper_cls = lookup(Qwen3VLVisionModel)
        self.assertIs(wrapper_cls, QuantQwen3VLVisionModel)

    def test_output_structure(self):
        """Test that output has correct structure."""
        ptq_config = PTQConfig()
        setattr(ptq_config, "vision_grid_thw", [1, 8, 8])
        q_model = QuantQwen3VLVisionModel(
            self.fp_model, qcfg=ptq_config, fp_name="test_model"
        )
        q_model.enable_calibration()

        hidden_states, grid_thw = self._create_test_inputs((1, 8, 8))
        _ = q_model(hidden_states, grid_thw)

        q_model.freeze_qparams()

        with torch.no_grad():
            q_out = q_model(hidden_states, grid_thw)

        # Check shapes
        expected_patches = math.prod(
            getattr(ptq_config, "vision_grid_thw")
        )  # t * h * w = 1 * 8 * 8

        # The structure of q_out depends on transformers version
        merged_hidden_states = (
            q_out.pooler_output
            if QuantQwen3VLVisionModel.transformers_version == "new"
            else q_out[0]
        )

        # // 4 reflects the 2x2 spatial merge performed by the merger module.
        self.assertEqual(merged_hidden_states.shape[0], expected_patches // 4)

    def test_different_grid_sizes(self):
        """Test with different grid sizes."""
        test_cases = [
            ((1, 4, 4), "small_image"),
            ((1, 6, 6), "medium_image"),
            ((1, 8, 8), "large_image"),
        ]

        grid_thw_list: tuple[int, int, int]
        description: str
        for grid_thw_list, description in test_cases:
            with self.subTest(description=description):
                ptq_config = PTQConfig()
                setattr(ptq_config, "vision_grid_thw", grid_thw_list)
                q_model = QuantQwen3VLVisionModel(
                    self.fp_model, qcfg=ptq_config, fp_name=f"test_model_{description}"
                )

                hidden_states, grid_thw = self._create_test_inputs(grid_thw_list)

                q_model.enable_calibration()
                _ = q_model(hidden_states, grid_thw)
                q_model.freeze_qparams()

                with torch.no_grad():
                    q_out = q_model(hidden_states, grid_thw)

                # The structure of q_out depends on transformers version
                merged_hidden_states = (
                    q_out.pooler_output
                    if QuantQwen3VLVisionModel.transformers_version == "new"
                    else q_out[0]
                )

                expected_patches = math.prod(grid_thw_list)  # t * h * w
                self.assertEqual(merged_hidden_states.shape[0], expected_patches // 4)

0 commit comments

Comments
 (0)