Skip to content

Commit 6a07e4b

Browse files
author
d.savchenkov
committed
[quantization] Introduce wrapper for Qwen3VLTextRotaryEmbedding
This change introduces the QuantQwen3VLTextRotaryEmbedding wrapper to support post-training quantization of the Qwen3VLTextRotaryEmbedding module.

TICO-DCO-1.0-Signed-off-by: d.savchenkov <d.savchenkov@partner.samsung.com>
1 parent 6a244c2 commit 6a07e4b

File tree

4 files changed

+713
-0
lines changed

4 files changed

+713
-0
lines changed
Lines changed: 377 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,377 @@
1+
# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import importlib.util
16+
import unittest
17+
18+
import torch
19+
from tico.quantization.config.ptq import PTQConfig
20+
from tico.quantization.wrapq.dtypes import DType
21+
from tico.quantization.wrapq.mode import Mode
22+
from tico.quantization.wrapq.wrappers.qwen_vl.quant_text_rotary_embedding import (
23+
QuantQwen3VLTextRotaryEmbedding,
24+
)
25+
26+
27+
# Probe for the optional `transformers` dependency without importing it;
# the whole test class below is skipped when it is absent.
trans_spec = importlib.util.find_spec("transformers")
skip_msg = "transformers not installed — skipping Qwen3VLTextRotaryEmbedding tests"
29+
30+
31+
@unittest.skipUnless(trans_spec, skip_msg)
class TestQuantQwen3VLTextRotaryEmbedding(unittest.TestCase):
    """PTQ wrapper tests for Qwen3VLTextRotaryEmbedding.

    A single FP32 rotary-embedding module is built once in ``setUpClass``
    (deliberately tiny config so the tests stay fast) and shared by all
    test cases as the reference implementation.
    """

    # Shared fixtures, populated by setUpClass.
    fp_rope: torch.nn.Module  # FP32 reference rotary embedding
    hidden_size: int
    head_dim: int

    @classmethod
    def setUpClass(cls):
        from transformers.models.qwen3_vl.configuration_qwen3_vl import (
            Qwen3VLTextConfig,
        )
        from transformers.models.qwen3_vl.modeling_qwen3_vl import (
            Qwen3VLTextRotaryEmbedding,
        )

        # Use smaller config for testing
        cfg = Qwen3VLTextConfig(
            hidden_size=32,  # Smaller for testing
            num_attention_heads=4,
            max_position_embeddings=512,
        )
        cls.fp_rope = Qwen3VLTextRotaryEmbedding(cfg)
        cls.hidden_size = cfg.hidden_size
        cls.head_dim = (
            getattr(cfg, "head_dim", None) or cfg.hidden_size // cfg.num_attention_heads
        )  # 8

    def _sample(self, batch: int, seq_len: int):
        """Build a (hidden_states, position_ids) pair for a forward pass."""
        hidden = torch.randn(batch, seq_len, self.head_dim)
        position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch, -1)
        return hidden, position_ids

    def test_mode_transitions(self):
        """Quantization mode must advance NO_QUANT → CALIB → QUANT."""
        wrapped = QuantQwen3VLTextRotaryEmbedding(self.fp_rope)
        self.assertIs(wrapped._mode, Mode.NO_QUANT)

        wrapped.enable_calibration()
        self.assertIs(wrapped._mode, Mode.CALIB)

        # One forward pass while calibrating so observers see data.
        hidden, position_ids = self._sample(2, 64)
        _ = wrapped(hidden, position_ids)

        wrapped.freeze_qparams()
        self.assertIs(wrapped._mode, Mode.QUANT)

    def test_quantised_output_close(self):
        """Quantized (cos, sin) must be close to — but not identical with — FP32."""
        torch.manual_seed(42)
        wrapped = QuantQwen3VLTextRotaryEmbedding(self.fp_rope)
        wrapped.enable_calibration()

        # Calibrate across several sequence lengths.
        for seq_len in (32, 64, 128):
            hidden, position_ids = self._sample(2, seq_len)
            _ = wrapped(hidden, position_ids)

        wrapped.freeze_qparams()

        hidden, position_ids = self._sample(2, 64)

        with torch.no_grad():
            q_cos, q_sin = wrapped(hidden, position_ids)
            fp_cos, fp_sin = self.fp_rope(hidden, position_ids)

        diff_cos = (fp_cos - q_cos).abs().mean().item()
        diff_sin = (fp_sin - q_sin).abs().mean().item()

        self.assertGreater(diff_cos, 0.0)  # not identical
        self.assertGreater(diff_sin, 0.0)
        self.assertLess(diff_cos, 0.4)  # acceptably close
        self.assertLess(diff_sin, 0.4)
        self.assertEqual(fp_cos.shape, q_cos.shape)
        self.assertEqual(fp_sin.shape, q_sin.shape)

    def test_output_shape(self):
        """Outputs must have shape (batch_size, seq_len, head_dim)."""
        wrapped = QuantQwen3VLTextRotaryEmbedding(self.fp_rope)
        wrapped.enable_calibration()

        hidden, position_ids = self._sample(2, 64)
        _ = wrapped(hidden, position_ids)

        wrapped.freeze_qparams()

        with torch.no_grad():
            q_cos, q_sin = wrapped(hidden, position_ids)

        expected_shape = (2, 64, self.head_dim)
        self.assertEqual(q_cos.shape, expected_shape)
        self.assertEqual(q_sin.shape, expected_shape)

    def test_output_range(self):
        """cos/sin outputs must stay within [-1, 1] up to quantization error."""
        wrapped = QuantQwen3VLTextRotaryEmbedding(self.fp_rope)
        wrapped.enable_calibration()

        hidden, position_ids = self._sample(2, 64)
        _ = wrapped(hidden, position_ids)

        wrapped.freeze_qparams()

        with torch.no_grad():
            q_cos, q_sin = wrapped(hidden, position_ids)

        # Small tolerance absorbs quantization rounding at the extremes.
        self.assertLessEqual(q_cos.max(), 1.01)
        self.assertGreaterEqual(q_cos.min(), -1.01)
        self.assertLessEqual(q_sin.max(), 1.01)
        self.assertGreaterEqual(q_sin.min(), -1.01)

    def test_different_sequence_lengths(self):
        """Calibrating at the maximum length must cover shorter sequences too."""
        wrapped = QuantQwen3VLTextRotaryEmbedding(self.fp_rope)
        wrapped.enable_calibration()

        # Calibrate with MAXIMUM length so the observed range is widest.
        max_seq_len = 256
        for _ in range(3):
            hidden, position_ids = self._sample(2, max_seq_len)
            _ = wrapped(hidden, position_ids)

        wrapped.freeze_qparams()

        # Evaluate at several (shorter and equal) lengths.
        for seq_len in (32, 64, 128, 256):
            hidden, position_ids = self._sample(2, seq_len)

            with torch.no_grad():
                q_cos, q_sin = wrapped(hidden, position_ids)
                fp_cos, fp_sin = self.fp_rope(hidden, position_ids)

            diff_cos = (fp_cos - q_cos).abs().mean().item()
            diff_sin = (fp_sin - q_sin).abs().mean().item()

            self.assertLess(diff_cos, 0.4)
            self.assertLess(diff_sin, 0.4)
            self.assertEqual(q_cos.shape[0], 2)
            self.assertEqual(q_cos.shape[1], seq_len)
            self.assertEqual(q_cos.shape[2], self.head_dim)

    def test_dtype_override(self):
        """Per-tensor PTQConfig overrides must reach the cos/sin observers."""
        cfg = PTQConfig(
            default_dtype=DType.uint(8),
            overrides={
                "cos": {"dtype": DType.uint(4)},
                "sin": {"dtype": DType.uint(4)},
            },
        )
        wrapped = QuantQwen3VLTextRotaryEmbedding(self.fp_rope, qcfg=cfg)

        self.assertEqual(wrapped.obs_cos.dtype, DType.uint(4))
        self.assertEqual(wrapped.obs_sin.dtype, DType.uint(4))

    def test_activation_stats_collected(self):
        """Observers must accumulate statistics during calibration."""
        wrapped = QuantQwen3VLTextRotaryEmbedding(self.fp_rope)
        wrapped.enable_calibration()

        # One calibration pass to feed the observers.
        hidden, position_ids = self._sample(2, 64)
        _ = wrapped(hidden, position_ids)

        # Stats (or already-frozen qparams) must now be present.
        self.assertTrue(
            wrapped.obs_cos.has_qparams or wrapped.obs_cos.min_val.numel() > 0
        )
        self.assertTrue(
            wrapped.obs_sin.has_qparams or wrapped.obs_sin.min_val.numel() > 0
        )

        # After freezing, qparams must exist unconditionally.
        wrapped.freeze_qparams()
        self.assertTrue(wrapped.obs_cos.has_qparams)
        self.assertTrue(wrapped.obs_sin.has_qparams)

    def test_observer_count(self):
        """Wrapper must expose exactly 6 observers:
        inv_freq, freqs, freqs_mrope, emb, cos, sin.
        """
        wrapped = QuantQwen3VLTextRotaryEmbedding(self.fp_rope)
        self.assertEqual(len(list(wrapped._all_observers())), 6)

    def test_registration_in_registry(self):
        """Qwen3VLTextRotaryEmbedding must resolve to its quant wrapper."""
        from tico.quantization.wrapq.wrappers.qwen_vl.quant_text_rotary_embedding import (
            QuantQwen3VLTextRotaryEmbedding,
        )
        from tico.quantization.wrapq.wrappers.registry import lookup
        from transformers.models.qwen3_vl.modeling_qwen3_vl import (
            Qwen3VLTextRotaryEmbedding,
        )

        self.assertIs(
            lookup(Qwen3VLTextRotaryEmbedding), QuantQwen3VLTextRotaryEmbedding
        )

    def test_no_learnable_parameters(self):
        """Wrapper must hold only buffers, never trainable parameters."""
        wrapped = QuantQwen3VLTextRotaryEmbedding(self.fp_rope)

        # No parameters at all.
        self.assertEqual(len(list(wrapped.parameters())), 0)

        # inv_freq must be registered as a buffer, not a parameter.
        self.assertIsInstance(wrapped.inv_freq, torch.Tensor)
        self.assertIn("inv_freq", wrapped._buffers)

    def test_cos_sin_relationship(self):
        """cos² + sin² ≈ 1 must survive quantization (unit-circle property)."""
        wrapped = QuantQwen3VLTextRotaryEmbedding(self.fp_rope)
        wrapped.enable_calibration()

        hidden, position_ids = self._sample(2, 64)
        _ = wrapped(hidden, position_ids)

        wrapped.freeze_qparams()

        with torch.no_grad():
            q_cos, q_sin = wrapped(hidden, position_ids)

        # Allow a small band around 1.0 for quantization error.
        unit_circle = q_cos.pow(2) + q_sin.pow(2)
        self.assertGreaterEqual(unit_circle.min(), 0.95)
        self.assertLessEqual(unit_circle.max(), 1.05)

    def test_different_batch_sizes(self):
        """Calibration at one batch size must generalize to other batch sizes."""
        wrapped = QuantQwen3VLTextRotaryEmbedding(self.fp_rope)
        wrapped.enable_calibration()

        seq_len = 64
        # Calibrate with batch size 2.
        for _ in range(3):
            hidden, position_ids = self._sample(2, seq_len)
            _ = wrapped(hidden, position_ids)

        wrapped.freeze_qparams()

        # Evaluate with several batch sizes.
        for batch_size in (1, 2, 4):
            hidden, position_ids = self._sample(batch_size, seq_len)

            with torch.no_grad():
                q_cos, q_sin = wrapped(hidden, position_ids)
                fp_cos, fp_sin = self.fp_rope(hidden, position_ids)

            diff_cos = (fp_cos - q_cos).abs().mean().item()
            diff_sin = (fp_sin - q_sin).abs().mean().item()

            self.assertLess(diff_cos, 0.4)
            self.assertLess(diff_sin, 0.4)
            self.assertEqual(q_cos.shape[0], batch_size)

    def test_mrope_semantic_equivalence(self):
        """apply_interleaved_mrope must match the original implementation exactly."""
        torch.manual_seed(42)

        head_dim = self.head_dim
        # Interleaved-MRoPE frequencies: (3 sections, batch, seq, head_dim/2).
        freqs = torch.randn(3, 2, 64, head_dim // 2)

        # Reference output from the original module.
        freqs_t_original = self.fp_rope.apply_interleaved_mrope(
            freqs, self.fp_rope.mrope_section
        )

        # Output from the wrapper's implementation.
        wrapped = QuantQwen3VLTextRotaryEmbedding(self.fp_rope)
        freqs_t_new = wrapped.apply_interleaved_mrope(freqs, wrapped.mrope_section)

        self.assertEqual(freqs_t_original.shape, freqs_t_new.shape)

        # Tight tolerance: the two implementations should agree.
        torch.testing.assert_close(
            freqs_t_original,
            freqs_t_new,
            rtol=1e-5,
            atol=1e-5,
            msg="MRoPE implementations produce different outputs",
        )

        # Repeat over a few other input shapes.
        test_configs = [
            (1, 32, head_dim),  # Single sample, shorter sequence
            (4, 128, head_dim),  # Larger batch, longer sequence
            (2, 256, head_dim),  # Very long sequence
        ]

        for bs, sl, hd in test_configs:
            freqs = torch.randn(3, bs, sl, hd // 2)

            freqs_t_original = self.fp_rope.apply_interleaved_mrope(
                freqs, self.fp_rope.mrope_section
            )
            freqs_t_new = wrapped.apply_interleaved_mrope(freqs, wrapped.mrope_section)

            self.assertEqual(freqs_t_original.shape, freqs_t_new.shape)
            self.assertTrue(
                torch.equal(freqs_t_original, freqs_t_new),
                f"MRoPE implementations differ for shape (3, {bs}, {sl}, {hd//2})",
            )

0 commit comments

Comments
 (0)