Skip to content

Commit fd824f2

Browse files
author
d.savchenkov
committed
[quantization] Introduce wrapper for Qwen3VLTextModel
This change introduces the QuantQwen3VLTextModel wrapper to support post-training quantization of the Qwen3VLTextModel module. TICO-DCO-1.0-Signed-off-by: d.savchenkov <d.savchenkov@partner.samsung.com>
1 parent 3c33243 commit fd824f2

File tree

7 files changed

+1462
-0
lines changed

7 files changed

+1462
-0
lines changed
Lines changed: 311 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,311 @@
1+
# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pathlib
16+
import tempfile
17+
import unittest
18+
import warnings
19+
20+
import tico
21+
22+
import torch
23+
from tico.quantization.config.ptq import PTQConfig
24+
from tico.quantization.wrapq.dtypes import DType
25+
from tico.quantization.wrapq.mode import Mode
26+
from tico.quantization.wrapq.utils.version import has_transformers_for
27+
from tico.quantization.wrapq.wrappers.nn.quant_layernorm import QuantLayerNorm
28+
from tico.quantization.wrapq.wrappers.qwen_vl.quant_text_decoder_layer import (
29+
QuantQwen3VLTextDecoderLayer,
30+
)
31+
32+
33+
# Reason shown by unittest when the required transformers version is absent.
skip_msg = "required transformers not installed — skipping Qwen3VLTextDecoderLayer tests"
36+
37+
38+
@unittest.skipUnless(has_transformers_for("qwen3-vl"), skip_msg)
class TestQuantQwen3VLTextDecoderLayer(unittest.TestCase):
    """Unit tests for the QuantQwen3VLTextDecoderLayer PTQ wrapper.

    A deliberately tiny Qwen3VLTextDecoderLayer is built once in
    ``setUpClass`` and shared by all tests. Each test wraps it with
    QuantQwen3VLTextDecoderLayer, drives the NO_QUANT -> CALIB -> QUANT
    lifecycle with random inputs, and checks one property: mode
    transitions, closeness of quantized vs. FP32 outputs, shape
    preservation, registry wiring, observer count, or PTQConfig
    override scoping.
    """

    # FP32 reference layer and its key dimensions, populated by setUpClass.
    fp_model: torch.nn.Module
    hidden_size: int
    num_attention_heads: int
    head_dim: int

    @classmethod
    def setUpClass(cls):
        # Imported lazily so the skipUnless guard above runs before any
        # transformers symbol is touched.
        from transformers.models.qwen3_vl.configuration_qwen3_vl import (
            Qwen3VLTextConfig,
        )
        from transformers.models.qwen3_vl.modeling_qwen3_vl import (
            Qwen3VLTextDecoderLayer,
        )

        # Use smaller sizes for testing.
        cfg = Qwen3VLTextConfig(
            hidden_size=64,
            num_attention_heads=2,
            num_key_value_heads=2,
            head_dim=32,
            max_position_embeddings=2048,
            intermediate_size=1024,
        )

        # Ensure eager attention implementation so outputs are deterministic
        # and do not require GPU flash attention kernels.
        # Some versions use `_attn_implementation`, others expose `attn_implementation`.
        if not hasattr(cfg, "_attn_implementation"):
            setattr(cfg, "_attn_implementation", "eager")
        else:
            cfg._attn_implementation = "eager"

        cls.fp_model = Qwen3VLTextDecoderLayer(cfg, layer_idx=0)
        cls.hidden_size = cfg.hidden_size
        cls.num_attention_heads = cfg.num_attention_heads
        # Read head_dim straight from the config rather than deriving it as
        # hidden_size // num_attention_heads: Qwen3 configs accept an explicit
        # head_dim that need not equal that ratio, and the rotary cos/sin
        # tensors built in _rand_position_embeddings must match the config
        # value. (With the sizes above both expressions give 32, so behavior
        # is unchanged; this just removes the fragile duplication.)
        cls.head_dim = cfg.head_dim

    def _rand_position_embeddings(self, batch_size, seq_len):
        """Create dummy rotary position embeddings (cos, sin) of shape
        (batch_size, seq_len, head_dim)."""
        cos = torch.randn(batch_size, seq_len, self.head_dim)
        sin = torch.randn(batch_size, seq_len, self.head_dim)
        return cos, sin

    def _create_test_inputs(self, batch_size=2, seq_len=16):
        """Create a full random input set for the decoder layer.

        Returns (hidden_states, position_embeddings, attention_mask,
        position_ids); the all-ones mask means no position is masked out.
        """
        hidden_states = torch.randn(batch_size, seq_len, self.hidden_size)
        position_embeddings = self._rand_position_embeddings(batch_size, seq_len)
        attention_mask = torch.ones(batch_size, 1, seq_len, seq_len)
        position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)
        return hidden_states, position_embeddings, attention_mask, position_ids

    def test_mode_transitions(self):
        """Test quantization mode transitions: NO_QUANT → CALIB → QUANT."""
        q_model = QuantQwen3VLTextDecoderLayer(self.fp_model)
        self.assertIs(q_model._mode, Mode.NO_QUANT)

        q_model.enable_calibration()
        self.assertIs(q_model._mode, Mode.CALIB)

        # Run forward pass during calibration so observers see real data.
        hidden_states, pos_emb, attn_mask, pos_ids = self._create_test_inputs()
        _ = q_model(
            hidden_states=hidden_states,
            position_embeddings=pos_emb,
            attention_mask=attn_mask,
            position_ids=pos_ids,
        )

        q_model.freeze_qparams()
        self.assertIs(q_model._mode, Mode.QUANT)

    def test_forward_diff(self):
        """
        Test that quantized output is acceptably close to FP32 reference.
        """
        torch.manual_seed(42)
        q_model = QuantQwen3VLTextDecoderLayer(self.fp_model)
        q_model.enable_calibration()

        # Calibrate with multiple inputs so observers see a range of values.
        for _ in range(4):
            hidden_states, pos_emb, attn_mask, pos_ids = self._create_test_inputs()
            _ = q_model(
                hidden_states=hidden_states,
                position_embeddings=pos_emb,
                attention_mask=attn_mask,
                position_ids=pos_ids,
            )

        q_model.freeze_qparams()

        hidden_states, pos_emb, attn_mask, pos_ids = self._create_test_inputs()
        with torch.no_grad():
            q_out = q_model(
                hidden_states=hidden_states,
                position_embeddings=pos_emb,
                attention_mask=attn_mask,
                position_ids=pos_ids,
            )
            # NOTE(review): both paths are treated as returning a plain
            # tensor (shape/sub are used directly) — assumes this
            # transformers version's decoder layer returns hidden states,
            # not a tuple.
            fp_out = self.fp_model(
                hidden_states=hidden_states,
                position_embeddings=pos_emb,
                attention_mask=attn_mask,
                position_ids=pos_ids,
            )

        self.assertEqual(fp_out.shape, q_out.shape)
        diff = (fp_out - q_out).abs().mean().item()
        self.assertGreater(diff, 0.0)  # not identical
        self.assertLess(diff, 0.7)  # acceptably close

    def test_registration_in_registry(self):
        """
        Test that Qwen3VLTextDecoderLayer is properly registered.
        """
        from tico.quantization.wrapq.wrappers.qwen_vl.quant_text_decoder_layer import (
            QuantQwen3VLTextDecoderLayer,
        )
        from tico.quantization.wrapq.wrappers.registry import lookup
        from transformers.models.qwen3_vl.modeling_qwen3_vl import (
            Qwen3VLTextDecoderLayer,
        )

        wrapper_cls = lookup(Qwen3VLTextDecoderLayer)
        self.assertIs(wrapper_cls, QuantQwen3VLTextDecoderLayer)

    def test_output_shape(self):
        """
        Test that output shape is preserved.
        Input: (batch_size, seq_len, hidden_size)
        Output: (batch_size, seq_len, hidden_size)
        """
        q_model = QuantQwen3VLTextDecoderLayer(self.fp_model)
        q_model.enable_calibration()

        batch_size = 2
        seq_len = 16
        hidden_states, pos_emb, attn_mask, pos_ids = self._create_test_inputs(
            batch_size, seq_len
        )
        _ = q_model(
            hidden_states=hidden_states,
            position_embeddings=pos_emb,
            attention_mask=attn_mask,
            position_ids=pos_ids,
        )

        q_model.freeze_qparams()

        with torch.no_grad():
            q_out = q_model(
                hidden_states=hidden_states,
                position_embeddings=pos_emb,
                attention_mask=attn_mask,
                position_ids=pos_ids,
            )
            fp_out = self.fp_model(
                hidden_states=hidden_states,
                position_embeddings=pos_emb,
                attention_mask=attn_mask,
                position_ids=pos_ids,
            )

        expected_shape = (batch_size, seq_len, self.hidden_size)
        self.assertEqual(q_out.shape, expected_shape)
        self.assertEqual(fp_out.shape, expected_shape)

    def test_residual_connection_preservation(self):
        """
        Test that residual connections are preserved (output close to input + transformation).
        """
        q_model = QuantQwen3VLTextDecoderLayer(self.fp_model)
        q_model.enable_calibration()

        hidden_states, pos_emb, attn_mask, pos_ids = self._create_test_inputs()
        _ = q_model(
            hidden_states=hidden_states,
            position_embeddings=pos_emb,
            attention_mask=attn_mask,
            position_ids=pos_ids,
        )

        q_model.freeze_qparams()

        with torch.no_grad():
            # Save input
            input_copy = hidden_states.clone()

            # Run forward pass
            output = q_model(
                hidden_states=hidden_states,
                position_embeddings=pos_emb,
                attention_mask=attn_mask,
                position_ids=pos_ids,
            )

            # Output should be different from input (transformation applied)
            self.assertFalse(torch.equal(output, input_copy))

            # But shape should be preserved
            self.assertEqual(output.shape, input_copy.shape)

    def test_observer_count(self):
        """
        Test that the wrapper has the correct number of observers.
        - 3 local observers (input, post_attn, output)
        """
        q_model = QuantQwen3VLTextDecoderLayer(self.fp_model)
        observers = list(q_model._all_observers())
        # Should have 3 local observers
        self.assertEqual(len(observers), 3)

    def test_per_module_override(self):
        """
        Test that PTQConfig overrides propagate correctly to submodules.
        """
        cfg = PTQConfig(
            default_dtype=DType.uint(8),
            overrides={
                "self_attn": {
                    "act_in": {"dtype": DType.uint(4)},
                }
            },
        )
        q_model = QuantQwen3VLTextDecoderLayer(self.fp_model, qcfg=cfg)

        # The override is scoped to the nested "self_attn" wrapper, so the
        # layer's OWN local input observer must keep the default dtype —
        # i.e. the override does not leak outside its scope.
        self.assertEqual(q_model.obs_act_in.dtype, DType.uint(8))

    def test_different_batch_sizes(self):
        """
        Test that quantization works correctly with different batch sizes.
        """
        q_model = QuantQwen3VLTextDecoderLayer(self.fp_model)
        q_model.enable_calibration()

        # Calibrate with one batch size only...
        calibrate_hidden, pos_emb, attn_mask, pos_ids = self._create_test_inputs(
            batch_size=2
        )
        for _ in range(3):
            _ = q_model(
                hidden_states=calibrate_hidden,
                position_embeddings=pos_emb,
                attention_mask=attn_mask,
                position_ids=pos_ids,
            )
        q_model.freeze_qparams()

        # ...then run inference with several batch sizes; the frozen qparams
        # must generalize across batch dimensions.
        for batch_size in [1, 2, 4]:
            hidden_states, pos_emb, attn_mask, pos_ids = self._create_test_inputs(
                batch_size=batch_size
            )
            with torch.no_grad():
                q_out = q_model(
                    hidden_states=hidden_states,
                    position_embeddings=pos_emb,
                    attention_mask=attn_mask,
                    position_ids=pos_ids,
                )
                fp_out = self.fp_model(
                    hidden_states=hidden_states,
                    position_embeddings=pos_emb,
                    attention_mask=attn_mask,
                    position_ids=pos_ids,
                )

            self.assertEqual(q_out.shape, fp_out.shape)
            diff = (fp_out - q_out).abs().mean().item()
            self.assertLess(diff, 0.8)

0 commit comments

Comments
 (0)