Skip to content

Commit 6a244c2

Browse files
authored
[quantization] Introduce wrapper for Qwen3VLVisionPatchEmbed (#488)
This change introduces the QuantQwen3VLVisionPatchEmbed wrapper to support post-training quantization of the Qwen3VLVisionPatchEmbed module.

TICO-DCO-1.0-Signed-off-by: d.savchenkov <d.savchenkov@partner.samsung.com>
1 parent 3dc76ec commit 6a244c2

File tree

4 files changed

+452
-0
lines changed

4 files changed

+452
-0
lines changed
Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import importlib.util
16+
import unittest
17+
18+
import torch
19+
from tico.quantization.config.ptq import PTQConfig
20+
from tico.quantization.wrapq.dtypes import DType
21+
from tico.quantization.wrapq.mode import Mode
22+
from tico.quantization.wrapq.wrappers.nn.quant_conv3d import QuantConv3d
23+
from tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_patch_embed import (
24+
QuantQwen3VLVisionPatchEmbed,
25+
)
26+
27+
28+
trans_spec = importlib.util.find_spec("transformers")
29+
skip_msg = "transformers not installed — skipping Qwen3VLVisionPatchEmbed tests"
30+
31+
32+
@unittest.skipUnless(trans_spec, skip_msg)
class TestQuantQwen3VLVisionPatchEmbed(unittest.TestCase):
    """Tests for the QuantQwen3VLVisionPatchEmbed post-training-quantization wrapper."""

    # Shared FP32 reference module, built once for the whole class.
    fp_patch_embed: torch.nn.Module
    hidden_size: int

    # Canonical input layout: (batch, channels, frames, height, width).
    SAMPLE_SHAPE = (2, 3, 4, 32, 32)

    @classmethod
    def setUpClass(cls):
        from transformers.models.qwen3_vl.configuration_qwen3_vl import (
            Qwen3VLVisionConfig,
        )
        from transformers.models.qwen3_vl.modeling_qwen3_vl import (
            Qwen3VLVisionPatchEmbed,
        )

        # A reduced hidden size keeps the tests fast.
        config = Qwen3VLVisionConfig(
            hidden_size=64,
            spatial_merge_size=2,
            temporal_merge_size=2,
        )
        cls.fp_patch_embed = Qwen3VLVisionPatchEmbed(config)
        cls.hidden_size = config.hidden_size

    def _fresh_wrapper(self):
        # Wrap the shared FP32 module in a new quantization wrapper.
        return QuantQwen3VLVisionPatchEmbed(self.fp_patch_embed)

    def test_mode_transitions(self):
        """Mode transitions happen in order: NO_QUANT -> CALIB -> QUANT."""
        wrapper = self._fresh_wrapper()
        self.assertIs(wrapper._mode, Mode.NO_QUANT)

        wrapper.enable_calibration()
        self.assertIs(wrapper._mode, Mode.CALIB)

        # One forward pass while calibrating.
        _ = wrapper(torch.randn(*self.SAMPLE_SHAPE))

        wrapper.freeze_qparams()
        self.assertIs(wrapper._mode, Mode.QUANT)

    def test_forward_diff(self):
        """
        Quantized output should differ from the FP32 reference (quantization
        actually applied) while staying within reasonable error bounds.
        """
        wrapper = self._fresh_wrapper()
        wrapper.enable_calibration()

        # Several calibration batches.
        for _ in range(4):
            _ = wrapper(torch.randn(*self.SAMPLE_SHAPE))
        wrapper.freeze_qparams()

        sample = torch.randn(*self.SAMPLE_SHAPE)
        with torch.no_grad():
            quant_out = wrapper(sample)
            ref_out = self.fp_patch_embed(sample)

        mean_err = (ref_out - quant_out).abs().mean().item()
        self.assertGreater(mean_err, 0.0)  # not identical
        self.assertLess(mean_err, 0.4)  # acceptably close
        self.assertEqual(ref_out.shape, quant_out.shape)

    def test_proj_override(self):
        """PTQConfig overrides must propagate to the wrapped Conv3d layer."""
        cfg = PTQConfig(
            default_dtype=DType.uint(8),
            overrides={
                "proj": {
                    "weight": {"dtype": DType.uint(4)},
                    "act_in": {"dtype": DType.uint(4)},
                    "act_out": {"dtype": DType.uint(4)},
                }
            },
        )
        wrapper = QuantQwen3VLVisionPatchEmbed(self.fp_patch_embed, qcfg=cfg)
        inner_conv = wrapper.proj.wrapped

        self.assertIsInstance(inner_conv, QuantConv3d)
        for observer in (
            inner_conv.obs_weight,
            inner_conv.obs_act_in,
            inner_conv.obs_act_out,
        ):
            self.assertEqual(observer.dtype, DType.uint(4))

    def test_activation_stats_collected(self):
        """
        Calibration must populate both the wrapper's local observers and the
        observers of the wrapped Conv3d; freezing must then yield qparams.
        """
        wrapper = self._fresh_wrapper()
        wrapper.enable_calibration()

        # One forward pass to collect statistics.
        _ = wrapper(torch.randn(*self.SAMPLE_SHAPE))

        inner_conv = wrapper.proj.wrapped
        observers = (
            wrapper.obs_hidden,
            wrapper.obs_output,
            inner_conv.obs_act_in,
            inner_conv.obs_act_out,
            inner_conv.obs_weight,
        )
        for observer in observers:
            self.assertTrue(observer.min_val.numel() > 0)

        # After freezing, every observer must expose quantization parameters.
        wrapper.freeze_qparams()
        for observer in observers:
            self.assertTrue(observer.has_qparams)

    def test_observer_count(self):
        """
        Two local observers (obs_hidden, obs_output) plus three from the
        wrapped Conv3d (obs_weight, obs_act_in, obs_act_out) = five total.
        """
        wrapper = self._fresh_wrapper()
        self.assertEqual(len(list(wrapper._all_observers())), 5)

    def test_registration_in_registry(self):
        """Qwen3VLVisionPatchEmbed must map to its quant wrapper in the registry."""
        from tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_patch_embed import (
            QuantQwen3VLVisionPatchEmbed,
        )
        from tico.quantization.wrapq.wrappers.registry import lookup
        from transformers.models.qwen3_vl.modeling_qwen3_vl import (
            Qwen3VLVisionPatchEmbed,
        )

        self.assertIs(lookup(Qwen3VLVisionPatchEmbed), QuantQwen3VLVisionPatchEmbed)

    def test_output_shape(self):
        """Quantized forward preserves the FP32 output shape."""
        wrapper = self._fresh_wrapper()
        wrapper.enable_calibration()

        sample = torch.randn(*self.SAMPLE_SHAPE)
        _ = wrapper(sample)

        wrapper.freeze_qparams()

        with torch.no_grad():
            quant_out = wrapper(sample)
            ref_out = self.fp_patch_embed(sample)

        self.assertEqual(quant_out.shape, ref_out.shape)

    def test_multiple_calibration_steps(self):
        """Statistics accumulate correctly over several calibration passes."""
        wrapper = self._fresh_wrapper()
        wrapper.enable_calibration()

        # Multiple calibration steps.
        for _ in range(5):
            _ = wrapper(torch.randn(*self.SAMPLE_SHAPE))
        wrapper.freeze_qparams()

        inner_conv = wrapper.proj.wrapped
        for observer in (
            wrapper.obs_hidden,
            wrapper.obs_output,
            inner_conv.obs_act_in,
            inner_conv.obs_act_out,
            inner_conv.obs_weight,
        ):
            self.assertTrue(observer.has_qparams)

    def test_different_batch_sizes(self):
        """A model calibrated at one batch size must handle other batch sizes."""
        wrapper = self._fresh_wrapper()
        wrapper.enable_calibration()

        # Calibrate at a single, fixed batch size.
        calib_batch = torch.randn(*self.SAMPLE_SHAPE)
        for _ in range(3):
            _ = wrapper(calib_batch)
        wrapper.freeze_qparams()

        # Evaluate at several other batch sizes.
        for batch in (1, 2, 4):
            sample = torch.randn(batch, 3, 4, 32, 32)
            with torch.no_grad():
                quant_out = wrapper(sample)
                ref_out = self.fp_patch_embed(sample)

            self.assertEqual(quant_out.shape, ref_out.shape)
            self.assertLess((ref_out - quant_out).abs().mean().item(), 0.4)
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
17+
import importlib.util
18+
import sys
19+
20+
import torch
21+
import torch.nn as nn
22+
23+
import tico
24+
import tico.quantization
25+
import tico.quantization.config.ptq
26+
27+
# Check if transformers is available
28+
trans_spec = importlib.util.find_spec("transformers")
29+
if trans_spec is None:
30+
print(
31+
"Error: transformers package not installed. Cannot test Qwen3VLVisionPatchEmbed."
32+
)
33+
sys.exit(1)
34+
35+
from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLVisionConfig
36+
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionPatchEmbed
37+
38+
39+
def generate_calibration_data(batch_size: int, sample_shape) -> list:
    """Return *batch_size* random tensors of *sample_shape* for PTQ calibration."""
    return [torch.randn(sample_shape) for _ in range(batch_size)]
46+
47+
48+
def main():
    """Run a PTQ round-trip of Qwen3VLVisionPatchEmbed and export to Circle.

    Steps: build the FP32 module, calibrate with random data, convert to a
    quantized model, export it to the Circle format, and save to disk.
    """
    # Create the vision patch embed model under test.
    cfg = Qwen3VLVisionConfig(
        in_channels=3,
        hidden_size=1024,
        temporal_merge_size=2,
        patch_size=16,
    )
    model = Qwen3VLVisionPatchEmbed(cfg)
    model.eval()

    # Expected structure:
    # Qwen3VLVisionPatchEmbed(
    #   (proj): Conv3d(3, 1024, kernel_size=(2, 16, 16), stride=(2, 16, 16))
    # )
    assert model.proj.in_channels == 3
    assert model.proj.out_channels == 1024
    assert model.proj.kernel_size == (2, 16, 16)
    assert model.proj.stride == (2, 16, 16)

    # Generate calibration data.
    # Input shape: (batch_size, in_channels, depth, height, width)
    # Example: (2, 3, 8, 224, 224) - 2 videos, RGB, 8 frames, 224x224 resolution
    calibration_data = generate_calibration_data(
        batch_size=20, sample_shape=(2, 3, 8, 224, 224)
    )

    # Configure PTQ with default settings.
    ptq_config = tico.quantization.config.ptq.PTQConfig()

    # Prepare the model for quantization (inserts observers in place).
    prepared_model = tico.quantization.prepare(
        model, ptq_config, inplace=True  # Transform the model in place
    )

    # Calibrate the model (collect activation statistics).
    with torch.no_grad():
        for batch in calibration_data:
            prepared_model(batch)

    # Convert to the quantized model.
    quantized_model = tico.quantization.convert(prepared_model, inplace=True)

    # Convert to Circle format.
    # example_inputs shape: (batch_size, in_channels, depth, height, width)
    example_inputs = (torch.randn(2, 3, 8, 224, 224),)
    circle_model = tico.convert(quantized_model, example_inputs)

    # Save the Circle model.
    filename = "quantized_vision_patch_embed.circle"
    circle_model.save(filename)
    # Fix: the f-string previously contained no placeholder, so the actual
    # output path stored in `filename` was never printed.
    print(f"Circle model saved as '{filename}'")


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)