Commit 45bb0ff

szyszyzys authored and facebook-github-bot committed

Add test function for the groupwise LUT quantization

Differential Revision: D79120101
1 parent 49649a4 · commit 45bb0ff

1 file changed: 168 additions, 0 deletions

@@ -0,0 +1,168 @@
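For context, here is a minimal sketch of the API this test exercises. The import paths, config fields, and the quantize_ call are taken from the diff below; the concrete values (4-bit weights, LUT groups of 256, scale groups of 64) are illustrative, not prescriptive:

import torch.nn as nn
from torch import uint4
from torchao.prototype.quantization.codebook_groupwise.api import (
    GroupwiseLutWeightConfig,
)
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import quantize_

model = nn.Sequential(nn.Linear(64, 32))
config = GroupwiseLutWeightConfig(
    weight_dtype=uint4,              # codebook index width: uint1..uint4
    lut_granularity=PerGroup(256),   # one lookup table per group of 256 weights
    scale_granularity=PerGroup(64),  # optional per-group scales (None to disable)
    use_qdq_reference=False,         # False: packed custom tensor; True: Python qdq reference
)
quantize_(model, config)             # replaces the Linear weight in place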
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy
import tempfile
import unittest

import torch
import torch.nn as nn
from parameterized import param, parameterized
from torch import uint1, uint2, uint3, uint4
from torchao.prototype.quantization.codebook_groupwise.api import (
    GroupwiseLutWeightConfig,
)
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import quantize_


class TestGroupwiseLowbitWeightLut(unittest.TestCase):
    """
    Test suite for the GroupwiseLutWeight quantization scheme, updated for the
    new simplified API.
    """

    TEST_CASES = [
        param(
            weight_dtype=weight_dtype,
            lut_group_size=lut_group_size,
            scale_group_size=scale_group_size,
            model_dtype=model_dtype,
            has_bias=has_bias,
            has_scales=has_scales,
        )
        for weight_dtype in [uint1, uint2, uint3, uint4]
        for lut_group_size, scale_group_size in [(256, 64), (256, 32)]
        for model_dtype in [torch.float32]
        for has_bias in [True, False]
        for has_scales in [True, False]
    ]
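    # The grid above expands to 4 weight dtypes x 2 group-size pairs x 2 bias
    # settings x 2 scale settings = 32 parameterized cases (model_dtype is
    # fixed at float32).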

    # --------------------------------------------------------------------------
    # Test 1: End-to-End Model Accuracy
    # --------------------------------------------------------------------------
    @parameterized.expand(TEST_CASES)
    def test_e2e_accuracy_vs_reference(
        self,
        weight_dtype,
        lut_group_size,
        scale_group_size,
        model_dtype,
        has_bias,
        has_scales,
    ):
        """
        Tests the numerical accuracy of the full quantized model against a
        reference. This now uses the `use_qdq_reference` flag instead of
        layout objects.
        """
        m, k, n = 3, 64, 32
        activations = torch.randn(m, k, dtype=model_dtype)
        model = nn.Sequential(nn.Linear(k, n, bias=has_bias)).to(dtype=model_dtype)

        lut_granularity = PerGroup(lut_group_size)
        scale_granularity = PerGroup(scale_group_size) if has_scales else None

        # --- Quantize using C++ ops ---
        quantized_model = copy.deepcopy(model)
        perf_config = GroupwiseLutWeightConfig(
            weight_dtype=weight_dtype,
            lut_granularity=lut_granularity,
            scale_granularity=scale_granularity,
            use_qdq_reference=False,  # This creates the custom tensor
        )
        quantize_(quantized_model, perf_config)
        with torch.no_grad():
            actual_result = quantized_model(activations)

        # --- Quantize for reference (using Python ops) ---
        reference_model = copy.deepcopy(model)
        ref_config = GroupwiseLutWeightConfig(
            weight_dtype=weight_dtype,
            lut_granularity=lut_granularity,
            scale_granularity=scale_granularity,
            use_qdq_reference=True,
        )
        quantize_(reference_model, ref_config)
        with torch.no_grad():
            expected_result = reference_model(activations)
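
        # Note: this kernel-vs-reference comparison uses looser tolerances
        # (atol/rtol = 1e-2) than the export/compile checks in Test 2 (1e-3).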
        # Compare results
        self.assertTrue(
            torch.allclose(actual_result, expected_result, atol=1e-2, rtol=1e-2)
        )

    def tearDown(self):
        """
        Clear the TorchDynamo cache after each test case to prevent
        recompilation errors in parameterized tests.
        """
        super().tearDown()
        torch._dynamo.reset()

    # --------------------------------------------------------------------------
    # Test 2: Deployment Readiness (Updated for new API)
    # --------------------------------------------------------------------------
    @parameterized.expand(TEST_CASES)
    def test_export_compile_aoti(
        self,
        weight_dtype,
        lut_group_size,
        scale_group_size,
        model_dtype,
        has_bias,
        has_scales,
    ):
        """
        Tests that the quantized model can be exported and compiled.
        """
        k, n = 64, 32
        activations = torch.randn(2, k, dtype=model_dtype)
        model = (
            nn.Sequential(nn.Linear(k, n, bias=has_bias)).to(dtype=model_dtype).eval()
        )

        # Configure the quantization using the new API
        config = GroupwiseLutWeightConfig(
            weight_dtype=weight_dtype,
            lut_granularity=PerGroup(lut_group_size),
            scale_granularity=PerGroup(scale_group_size) if has_scales else None,
            use_qdq_reference=False,  # Ensure we are testing the custom tensor
        )
        quantize_(model, config)

        with torch.no_grad():
            eager_results = model(activations)

        # Export and compile
        exported_model = torch.export.export(model, (activations,))
        compiled_model = torch.compile(model, fullgraph=True)

        with tempfile.TemporaryDirectory() as tmpdir, torch.no_grad():
            # Check the exported model
            exported_results = exported_model.module()(activations)
            self.assertTrue(
                torch.allclose(eager_results, exported_results, atol=1e-3, rtol=1e-3)
            )

            # Check the compiled model
            compiled_results = compiled_model(activations)
            self.assertTrue(
                torch.allclose(eager_results, compiled_results, atol=1e-3, rtol=1e-3)
            )

            # Check the AOTI-compiled model using the packaging API
            package_path = f"{tmpdir}/model.pt2"
            torch._inductor.aoti_compile_and_package(
                exported_model, package_path=package_path
            )
            aoti_model = torch._inductor.aoti_load_package(package_path)
            aoti_results = aoti_model(activations)
            self.assertTrue(
                torch.allclose(eager_results, aoti_results, atol=1e-3, rtol=1e-3)
            )


if __name__ == "__main__":
    unittest.main()
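
Since the file ends in unittest.main(), the suite can be run directly (hypothetical path; the file's location is not shown in this diff):

python <path/to/this_test_file>.py -v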
