     GroupwiseLutWeightConfig,
 )
 from torchao.quantization.quant_api import quantize_
-from torchao.quantization.granularity import PerGroup
-
+from torchao.prototype.quantization.codebook_utils.codebook_utils import (
+    group_size_to_block_shapes
+)
 
 class TestGroupwiseLowbitWeightLut(unittest.TestCase):
     """
@@ -27,30 +28,29 @@ class TestGroupwiseLowbitWeightLut(unittest.TestCase):
 
     TEST_CASES = [
         param(
-            weight_dtype=weight_dtype,
+            code_dtype=code_dtype,
             lut_group_size=lut_group_size,
             scale_group_size=scale_group_size,
-            model_dtype=model_dtype,
+            weight_dtype=weight_dtype,
             has_bias=has_bias,
             has_scales=has_scales,
         )
-        for weight_dtype in [uint1, uint2, uint3, uint4]
+        for code_dtype in [uint1, uint2, uint3, uint4]
         for lut_group_size, scale_group_size in [(256, 64), (256, 32)]
-        for model_dtype in [torch.float32]
+        for weight_dtype in [torch.float32]
         for has_bias in [True, False]
         for has_scales in [True, False]
     ]
-
     # --------------------------------------------------------------------------
     # Test 1: End-to-End Model Accuracy
     # --------------------------------------------------------------------------
     @parameterized.expand(TEST_CASES)
     def test_e2e_accuracy_vs_reference(
         self,
-        weight_dtype,
+        code_dtype,
         lut_group_size,
         scale_group_size,
-        model_dtype,
+        weight_dtype,
         has_bias,
         has_scales,
     ):
@@ -59,19 +59,20 @@ def test_e2e_accuracy_vs_reference(
         This now uses the `use_qdq_reference` flag instead of layout objects.
         """
         m, k, n = 3, 64, 32
-        activations = torch.randn(m, k, dtype=model_dtype)
-        model = nn.Sequential(nn.Linear(k, n, bias=has_bias)).to(dtype=model_dtype)
+        activations = torch.randn(m, k, dtype=weight_dtype)
+        model = nn.Sequential(nn.Linear(k, n, bias=has_bias)).to(dtype=weight_dtype)
 
-        lut_granularity = PerGroup(lut_group_size)
-        scale_granularity = PerGroup(scale_group_size) if has_scales else None
+        lut_block_shape, scale_block_shape = group_size_to_block_shapes(lut_group_size=lut_group_size, tensor_shape=(n, k), scale_group_size=scale_group_size if has_scales else None)
 
         # --- Quantize using C++ ops ---
         quantized_model = copy.deepcopy(model)
         perf_config = GroupwiseLutWeightConfig(
+            code_dtype=code_dtype,
             weight_dtype=weight_dtype,
-            lut_granularity=lut_granularity,
-            scale_granularity=scale_granularity,
-            use_qdq_reference=False,  # This creates the custom tensor
+            lut_block_shape=lut_block_shape,
+            scale_block_shape=scale_block_shape,
+            use_qdq_reference=False,
+            has_scale=has_scales,
         )
         quantize_(quantized_model, perf_config)
         with torch.no_grad():
@@ -80,10 +81,12 @@ def test_e2e_accuracy_vs_reference(
         # --- Quantize for Reference (using Python ops) ---
         reference_model = copy.deepcopy(model)
         ref_config = GroupwiseLutWeightConfig(
+            code_dtype=code_dtype,
             weight_dtype=weight_dtype,
-            lut_granularity=lut_granularity,
-            scale_granularity=scale_granularity,
+            lut_block_shape=lut_block_shape,
+            scale_block_shape=scale_block_shape,
             use_qdq_reference=True,
+            has_scale=has_scales,
         )
         quantize_(reference_model, ref_config)
         with torch.no_grad():
@@ -107,28 +110,31 @@ def tearDown(self):
     @parameterized.expand(TEST_CASES)
     def test_export_compile_aoti(
         self,
-        weight_dtype,
+        code_dtype,
         lut_group_size,
         scale_group_size,
-        model_dtype,
+        weight_dtype,
         has_bias,
         has_scales,
     ):
         """
         Tests that the quantized model can be exported and compiled.
         """
         k, n = 64, 32
-        activations = torch.randn(2, k, dtype=model_dtype)
+        activations = torch.randn(2, k, dtype=weight_dtype)
         model = (
-            nn.Sequential(nn.Linear(k, n, bias=has_bias)).to(dtype=model_dtype).eval()
+            nn.Sequential(nn.Linear(k, n, bias=has_bias)).to(dtype=weight_dtype).eval()
        )
+        lut_block_shape, scale_block_shape = group_size_to_block_shapes(lut_group_size=lut_group_size, tensor_shape=(n, k), scale_group_size=scale_group_size if has_scales else None)
 
         # Configure the quantization using the new API
         config = GroupwiseLutWeightConfig(
+            code_dtype=code_dtype,
             weight_dtype=weight_dtype,
-            lut_granularity=PerGroup(lut_group_size),
-            scale_granularity=PerGroup(scale_group_size) if has_scales else None,
-            use_qdq_reference=False,  # Ensure we are testing the custom tensor
+            lut_block_shape=lut_block_shape,
+            scale_block_shape=scale_block_shape,
+            use_qdq_reference=False,
+            has_scale=has_scales,
         )
         quantize_(model, config)
 
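For reviewers unfamiliar with the new API: the diff replaces `PerGroup(...)` granularity objects with explicit block shapes derived from the group sizes and the weight's shape via `group_size_to_block_shapes`. The sketch below shows one plausible reading of that mapping; the `_sketch` suffix, the whole-row/sub-row rules, and the asserts are my assumptions for illustration, not the torchao implementation (which lives in `torchao/prototype/quantization/codebook_utils/codebook_utils.py` and may differ).

```python
def group_size_to_block_shapes_sketch(lut_group_size, tensor_shape, scale_group_size=None):
    # Hypothetical re-implementation for illustration only.
    n, k = tensor_shape  # (out_features, in_features) of the Linear weight

    def to_block_shape(group_size):
        # A "group" is group_size consecutive weights taken row-major.
        # A group covering whole rows becomes a (rows, k) block; a group
        # smaller than one row becomes a (1, group_size) block.
        if group_size >= k:
            assert group_size % k == 0, "group must cover whole rows"
            return (group_size // k, k)
        assert k % group_size == 0, "group must evenly divide a row"
        return (1, group_size)

    lut_block = to_block_shape(lut_group_size)
    scale_block = to_block_shape(scale_group_size) if scale_group_size else None
    return lut_block, scale_block


# With the shapes used in these tests (n=32, k=64):
#   lut_group_size=256  -> (4, 64): one LUT shared by four output rows
#   scale_group_size=64 -> (1, 64): one scale per output row
print(group_size_to_block_shapes_sketch(256, (32, 64), scale_group_size=64))
# ((4, 64), (1, 64))
```

Presumably the motivation for the change: an explicit 2D block shape states exactly how a group tiles the weight matrix (including groups that span several rows), which a scalar `PerGroup` granularity left implicit.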