
Commit 5c23096

szyszyzys authored and facebook-github-bot committed
Move codebook (LUT) generation methods into common utils.
Differential Revision: D79595460
1 parent 45bb0ff · commit 5c23096

File tree

5 files changed: +645, -542 lines changed


setup.py

Lines changed: 2 additions & 2 deletions
@@ -728,8 +728,8 @@ def bool_to_on_off(value):
 
 
 # Only check submodules if we're going to build C++ extensions
-if use_cpp != "0":
-    check_submodules()
+# if use_cpp != "0":
+#     check_submodules()
 
 setup(
     name="torchao",

torchao/experimental/tests/test_groupwise_lowbit_weight_lut_quantizer.py

Lines changed: 31 additions & 25 deletions
@@ -16,8 +16,9 @@
     GroupwiseLutWeightConfig,
 )
 from torchao.quantization.quant_api import quantize_
-from torchao.quantization.granularity import PerGroup
-
+from torchao.prototype.quantization.codebook_utils.codebook_utils import (
+    group_size_to_block_shapes
+)
 
 class TestGroupwiseLowbitWeightLut(unittest.TestCase):
     """
@@ -27,30 +28,29 @@ class TestGroupwiseLowbitWeightLut(unittest.TestCase):
 
     TEST_CASES = [
         param(
-            weight_dtype=weight_dtype,
+            code_dtype=code_dtype,
             lut_group_size=lut_group_size,
             scale_group_size=scale_group_size,
-            model_dtype=model_dtype,
+            weight_dtype=weight_dtype,
             has_bias=has_bias,
             has_scales=has_scales,
         )
-        for weight_dtype in [uint1, uint2, uint3, uint4]
+        for code_dtype in [uint1, uint2, uint3, uint4]
         for lut_group_size, scale_group_size in [(256, 64), (256, 32)]
-        for model_dtype in [torch.float32]
+        for weight_dtype in [torch.float32]
         for has_bias in [True, False]
         for has_scales in [True, False]
     ]
-
     # --------------------------------------------------------------------------
     # Test 1: End-to-End Model Accuracy
     # --------------------------------------------------------------------------
     @parameterized.expand(TEST_CASES)
     def test_e2e_accuracy_vs_reference(
         self,
-        weight_dtype,
+        code_dtype,
         lut_group_size,
         scale_group_size,
-        model_dtype,
+        weight_dtype,
         has_bias,
         has_scales,
     ):
@@ -59,19 +59,20 @@ def test_e2e_accuracy_vs_reference(
         This now uses the `use_qdq_reference` flag instead of layout objects.
         """
         m, k, n = 3, 64, 32
-        activations = torch.randn(m, k, dtype=model_dtype)
-        model = nn.Sequential(nn.Linear(k, n, bias=has_bias)).to(dtype=model_dtype)
+        activations = torch.randn(m, k, dtype=weight_dtype)
+        model = nn.Sequential(nn.Linear(k, n, bias=has_bias)).to(dtype=weight_dtype)
 
-        lut_granularity = PerGroup(lut_group_size)
-        scale_granularity = PerGroup(scale_group_size) if has_scales else None
+        lut_block_shape, scale_block_shape = group_size_to_block_shapes(lut_group_size=lut_group_size, tensor_shape=(n, k), scale_group_size=scale_group_size if has_scales else None)
 
         # --- Quantize using C++ ops ---
         quantized_model = copy.deepcopy(model)
         perf_config = GroupwiseLutWeightConfig(
+            code_dtype=code_dtype,
             weight_dtype=weight_dtype,
-            lut_granularity=lut_granularity,
-            scale_granularity=scale_granularity,
-            use_qdq_reference=False,  # This creates the custom tensor
+            lut_block_shape=lut_block_shape,
+            scale_block_shape=scale_block_shape,
+            use_qdq_reference=False,
+            has_scale=has_scales,
         )
         quantize_(quantized_model, perf_config)
         with torch.no_grad():
@@ -80,10 +81,12 @@ def test_e2e_accuracy_vs_reference(
         # --- Quantize for Reference (using Python ops) ---
         reference_model = copy.deepcopy(model)
         ref_config = GroupwiseLutWeightConfig(
+            code_dtype=code_dtype,
             weight_dtype=weight_dtype,
-            lut_granularity=lut_granularity,
-            scale_granularity=scale_granularity,
+            lut_block_shape=lut_block_shape,
+            scale_block_shape=scale_block_shape,
             use_qdq_reference=True,
+            has_scale=has_scales,
         )
         quantize_(reference_model, ref_config)
         with torch.no_grad():
@@ -107,28 +110,31 @@ def tearDown(self):
     @parameterized.expand(TEST_CASES)
     def test_export_compile_aoti(
         self,
-        weight_dtype,
+        code_dtype,
         lut_group_size,
         scale_group_size,
-        model_dtype,
+        weight_dtype,
         has_bias,
         has_scales,
     ):
         """
         Tests that the quantized model can be exported and compiled.
         """
         k, n = 64, 32
-        activations = torch.randn(2, k, dtype=model_dtype)
+        activations = torch.randn(2, k, dtype=weight_dtype)
         model = (
-            nn.Sequential(nn.Linear(k, n, bias=has_bias)).to(dtype=model_dtype).eval()
+            nn.Sequential(nn.Linear(k, n, bias=has_bias)).to(dtype=weight_dtype).eval()
         )
+        lut_block_shape, scale_block_shape = group_size_to_block_shapes(lut_group_size=lut_group_size, tensor_shape=(n, k), scale_group_size=scale_group_size if has_scales else None)
 
         # Configure the quantization using the new API
         config = GroupwiseLutWeightConfig(
+            code_dtype=code_dtype,
             weight_dtype=weight_dtype,
-            lut_granularity=PerGroup(lut_group_size),
-            scale_granularity=PerGroup(scale_group_size) if has_scales else None,
-            use_qdq_reference=False,  # Ensure we are testing the custom tensor
+            lut_block_shape=lut_block_shape,
+            scale_block_shape=scale_block_shape,
+            use_qdq_reference=False,
+            has_scale=has_scales,
         )
         quantize_(model, config)
 
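For readers tracking the API migration in this commit, the sketch below condenses the pattern the updated test now exercises: flat group sizes are converted to block shapes with group_size_to_block_shapes, and those shapes (together with the new code_dtype and has_scale fields) are passed to GroupwiseLutWeightConfig before calling quantize_. This is a minimal illustration, not part of the diff: the layer shapes and group sizes are arbitrary, torch.uint4 is assumed to correspond to the test's bare uint4, and the import path of GroupwiseLutWeightConfig is not visible in the hunks above, so it is only referenced in a comment.

import torch
import torch.nn as nn

from torchao.prototype.quantization.codebook_utils.codebook_utils import (
    group_size_to_block_shapes,
)
from torchao.quantization.quant_api import quantize_
# GroupwiseLutWeightConfig is imported in the test above the visible hunk;
# its module path is not reproduced here.

k, n = 64, 32
model = nn.Sequential(nn.Linear(k, n, bias=True)).to(dtype=torch.float32)

# Translate flat group sizes into block shapes for an (n, k) weight tensor.
lut_block_shape, scale_block_shape = group_size_to_block_shapes(
    lut_group_size=256,
    tensor_shape=(n, k),
    scale_group_size=64,  # pass None when scales are disabled
)

config = GroupwiseLutWeightConfig(
    code_dtype=torch.uint4,      # assumed stand-in for the test's `uint4`
    weight_dtype=torch.float32,  # dtype of the float weights and activations
    lut_block_shape=lut_block_shape,
    scale_block_shape=scale_block_shape,
    use_qdq_reference=False,     # False: packed custom tensor; True: Python QDQ reference
    has_scale=True,
)
quantize_(model, config)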
