Commit 8b229a7

Make AWQ more general
Summary:

* Added AWQConfig that takes a base config and made corresponding changes in other parts of the flow

Test Plan:

Tested on Phi4-mini and Qwen3-8B

Qwen3-8B

| Task                       | calibration_limit | no-awq | awq    |
|----------------------------|-------------------|--------|--------|
| leaderboard_math_hard (v3) | 2                 | 0.3543 | 0.4371 |
| gpqa_main_zeroshot         | 50                | 0.32   | 0.36   |
| mmlu                       | 5                 | 0.7372 | 0.7463 |
| bbh                        | 1                 | 0.7385 | 0.7556 |

Phi4-mini

| Task     | calibration_limit | no-awq | awq    |
|----------|-------------------|--------|--------|
| mmlu_pro | 2                 | 0.4057 | 0.4757 |
| gsm8k    | 5                 | 0.72   | 0.76   |

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent ffaf572 commit 8b229a7
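
The new AWQConfig wraps an existing base quantization config and drives the flow in explicit steps (PREPARE, CONVERT, PREPARE_FOR_LOADING) through quantize_. Below is a minimal sketch of that prepare → calibrate → convert flow, pieced together from the tests in this commit; the toy module, random calibration batches, and the Int4WeightOnlyConfig base config are illustrative placeholders (the functional tests below use FbgemmConfig as the base config), and a CUDA device is assumed.

import torch
from torchao.prototype.awq import AWQConfig, AWQStep
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# Placeholder model and calibration batches, for illustration only.
model = torch.nn.Sequential(torch.nn.Linear(128, 128, bias=False))
model = model.eval().to(torch.bfloat16).to("cuda")
calibration_data = [
    torch.randn(5, 128, dtype=torch.bfloat16, device="cuda") for _ in range(10)
]

base_config = Int4WeightOnlyConfig(group_size=128)

# Step 1 (PREPARE): insert AWQ observers so activation statistics get recorded.
quantize_(model, AWQConfig(base_config, step=AWQStep.PREPARE))

# Step 2: calibrate by running representative inputs through the model.
for example in calibration_data:
    model(example)

# Step 3 (CONVERT): compute AWQ scales and quantize with the base config.
quantize_(model, AWQConfig(base_config, step=AWQStep.CONVERT))

# To load an already-quantized checkpoint (e.g. the vllm flow tested below), skip
# calibration and apply AWQConfig(base_config, step=AWQStep.PREPARE_FOR_LOADING)
# to a fresh model, then copy_ the saved weights in.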

File tree

13 files changed: +505 -519 lines changed


test/core/test_config.py

Lines changed: 6 additions & 0 deletions
@@ -19,6 +19,10 @@
     config_from_dict,
     config_to_dict,
 )
+from torchao.prototype.awq import (
+    AWQConfig,
+    AWQStep,
+)
 from torchao.quantization.quant_api import (
     FbgemmConfig,
     Float8DynamicActivationFloat8WeightConfig,
@@ -79,6 +83,8 @@
             "linear2": Int8DynamicActivationInt4WeightConfig(),
         }
     ),
+    AWQConfig(Int4WeightOnlyConfig(group_size=128), step=AWQStep.PREPARE_FOR_LOADING),
+    AWQConfig(Int4WeightOnlyConfig(group_size=128), step="prepare_for_loading"),
 ]
 
 if TORCH_VERSION_AT_LEAST_2_6:

test/prototype/test_awq.py

Lines changed: 203 additions & 142 deletions
@@ -3,29 +3,30 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-import os
-from copy import deepcopy
+import copy
+import tempfile
+import unittest
 
-import pytest
 import torch
+from torch.testing._internal.common_utils import (
+    TestCase,
+    run_tests,
+)
 
-from torchao.quantization import quantize_
-from torchao.testing.utils import skip_if_rocm
+from torchao.prototype.awq import AWQConfig, AWQStep
+from torchao.quantization import FbgemmConfig, Int4WeightOnlyConfig, quantize_
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_3,
-    TORCH_VERSION_AT_LEAST_2_5,
+    TORCH_VERSION_AT_LEAST_2_6,
+    _is_fbgemm_genai_gpu_available,
 )
 
-if TORCH_VERSION_AT_LEAST_2_3:
-    from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_
-
 
 class ToyLinearModel(torch.nn.Module):
     def __init__(self, m=512, n=256, k=128):
         super().__init__()
         self.linear1 = torch.nn.Linear(m, n, bias=False)
         self.linear2 = torch.nn.Linear(n, k, bias=False)
-        self.linear3 = torch.nn.Linear(k, 1, bias=False)
+        self.linear3 = torch.nn.Linear(k, 64, bias=False)
 
     def example_inputs(
         self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda"
@@ -44,137 +45,197 @@ def forward(self, x):
         return x
 
 
-devices = ["cpu", "cuda"]
-# torch.uintx dtypes are introduced in 2.3
-if TORCH_VERSION_AT_LEAST_2_3:
-    qdtypes = (torch.uint4, torch.uint7)
-else:
-    qdtypes = ()
-
-
-@pytest.fixture(autouse=True)
-def run_before_and_after_tests():
-    yield
-    torch._dynamo.reset()  # reset cache between tests
-
-
-@pytest.mark.parametrize("device", devices)
-@pytest.mark.parametrize("qdtype", qdtypes)
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
-@pytest.mark.skip("Temporarily skipping to unpin nightiles")
-def test_awq_loading(device, qdtype):
-    if qdtype == torch.uint4 and device == "cpu":
-        pytest.skip("uint4 not supported on cpu")
-
-    dataset_size = 100
-    l1, l2, l3 = 512, 256, 128
-    original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
-    quant_dtype = qdtype
-    group_size = 128
-    n_calibration_examples = 10
-    n_validation_examples = 10
-    sequence_length = 5
-
-    m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
-    dataset = m.example_inputs(
-        dataset_size,
-        sequence_length=sequence_length,
-        dtype=original_dtype,
-        device=device,
-    )
-    calibration_data = dataset[:n_calibration_examples]
-
-    # calibrate
-    insert_awq_observer_(
-        m,
-        n_validation_examples,
-        sequence_length,
-        quant_dtype=quant_dtype,
-        group_size=group_size,
-    )
-
-    for example in calibration_data:
-        m(example.to(device))
-
-    # quantize
-    is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
-    quantize_(
-        m, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear
-    )
-
-    model_save_path = "awq_model.pth"
-    torch.save(m, model_save_path)
-    loaded_model = torch.load(model_save_path)
-    os.remove(model_save_path)
-
-    if torch.cuda.is_available():
+@unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available")
+@unittest.skipIf(
+    not _is_fbgemm_genai_gpu_available(),
+    reason="need to install fbgemm_gpu_genai package",
+)
+@unittest.skipIf(
+    not TORCH_VERSION_AT_LEAST_2_6,
+    reason="torch.int4 needs torch 2.6+, can remove after we are not using FbgemmConfig",
+)
+class TestAWQ(TestCase):
+    def test_awq_config(self):
+        base_config = Int4WeightOnlyConfig()
+        AWQConfig(base_config, step=AWQStep.PREPARE)
+        AWQConfig(base_config, step=AWQStep.PREPARE_FOR_LOADING)
+        AWQConfig(base_config, step=AWQStep.CONVERT)
+
+        AWQConfig(base_config, step="prepare")
+        AWQConfig(base_config, step="prepare_for_loading")
+        AWQConfig(base_config, step="convert")
+
+        with self.assertRaisesRegex(ValueError, "is not one of"):
+            AWQConfig(base_config, step="not_supported")
+
+    def test_awq_functionality(self):
+        device = "cuda"
+        dataset_size = 100
+        l1, l2, l3 = 512, 256, 128
+        original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
+        group_size = 128
+        n_calibration_examples = 10
+        sequence_length = 5
+
+        m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
+
+        # baseline quantization
+        base_config = FbgemmConfig(
+            input_dtype=torch.bfloat16,
+            weight_dtype=torch.int4,
+            output_dtype=torch.bfloat16,
+            block_size=[1, group_size],
+            preshuffle=False,
+        )
+        m_baseline = copy.deepcopy(m)
+        quantize_(m_baseline, base_config)
+
+        # awq quantization
+        dataset = m.example_inputs(
+            dataset_size,
+            sequence_length=sequence_length,
+            dtype=original_dtype,
+            device=device,
+        )
+        ref_out = torch.cat([m(d.squeeze(0)) for d in dataset])
+
+        calibration_data = dataset[:n_calibration_examples]
+
+        quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
+        quantize_(m, quant_config)
+
+        for example in calibration_data:
+            m(example)
+
+        quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
+        quantize_(m, quant_config)
+
+        awq_out = torch.cat([m(d.squeeze(0)) for d in dataset])
+        baseline_out = torch.cat([m_baseline(d.squeeze(0)) for d in dataset])
+
+        loss_awq = (ref_out - awq_out).pow(2).mean().item()
+        loss_base = (ref_out - baseline_out).pow(2).mean().item()
+        assert loss_awq < loss_base
+
+    def test_awq_loading(self):
+        device = "cuda"
+        dataset_size = 100
+        l1, l2, l3 = 512, 256, 128
+        original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
+        group_size = 128
+        n_calibration_examples = 10
+        sequence_length = 5
+
+        m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
+        dataset = m.example_inputs(
+            dataset_size,
+            sequence_length=sequence_length,
+            dtype=original_dtype,
+            device=device,
+        )
+        calibration_data = dataset[:n_calibration_examples]
+
+        # calibrate
+        base_config = FbgemmConfig(
+            input_dtype=torch.bfloat16,
+            weight_dtype=torch.int4,
+            output_dtype=torch.bfloat16,
+            block_size=[1, group_size],
+            preshuffle=False,
+        )
+        quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
+        quantize_(m, quant_config)
+
+        for example in calibration_data:
+            m(example)
+
+        # quantize
+        quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
+        quantize_(m, quant_config)
+
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(m.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+
+        loaded_model = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
+        loaded_model.load_state_dict(state_dict, assign=True)
+
+        m = torch.compile(m, fullgraph=True)
+        loaded_model = torch.compile(loaded_model, fullgraph=True)
+
+        awq_out = torch.cat([m(d.squeeze(0)) for d in dataset])
+        awq_save_load_out = torch.cat([loaded_model(d.squeeze(0)) for d in dataset])
+
+        assert awq_out is not None
+        assert awq_save_load_out is not None
+        assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
+
+    def test_awq_loading_vllm(self):
+        """Simulate weight loading in vllm:
+        * prepare model weight to the same format (awq weight)
+        * use weight.copy_(state_dict["weight"]) to copy over the quantized weights from checkpoint
+
+        There is also a slicing op that is omitted here; the overall e2e flow is tested in the vllm repo
+        """
+        device = "cuda"
+        dataset_size = 100
+        l1, l2, l3 = 512, 256, 128
+        original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
+        group_size = 128
+        n_calibration_examples = 10
+        sequence_length = 5
+
+        m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
+        dataset = m.example_inputs(
+            dataset_size,
+            sequence_length=sequence_length,
+            dtype=original_dtype,
+            device=device,
+        )
+        calibration_data = dataset[:n_calibration_examples]
+
+        # calibrate
+        base_config = FbgemmConfig(
+            input_dtype=torch.bfloat16,
+            weight_dtype=torch.int4,
+            output_dtype=torch.bfloat16,
+            block_size=[1, group_size],
+            preshuffle=False,
+        )
+        quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
+        quantize_(m, quant_config)
+
+        for example in calibration_data:
+            m(example)
+
+        # quantize
+        quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
+        quantize_(m, quant_config)
+
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(m.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+
+        loaded_model = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
+        quant_config = AWQConfig(base_config, step=AWQStep.PREPARE_FOR_LOADING)
+        quantize_(loaded_model, quant_config)
+
+        loaded_model.linear1.weight.copy_(state_dict["linear1.weight"])
+        loaded_model.linear2.weight.copy_(state_dict["linear2.weight"])
+        loaded_model.linear3.weight.copy_(state_dict["linear3.weight"])
+
         m = torch.compile(m, fullgraph=True)
         loaded_model = torch.compile(loaded_model, fullgraph=True)
 
-        awq_out = torch.cat([m(i.squeeze(0)) for i in dataset])
-        awq_save_load_out = torch.cat([loaded_model(i.squeeze(0)) for i in dataset])
-
-        assert awq_out is not None
-        assert awq_save_load_out is not None
-        assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
-
-
-@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@skip_if_rocm("ROCm enablement in progress")
-def test_save_weights_only():
-    dataset_size = 100
-    l1, l2, l3 = 512, 256, 128
-    original_dtype = torch.bfloat16
-    quant_dtype = torch.uint4
-    device = "cuda"
-    group_size = 128
-    n_calibration_examples = 10
-    n_validation_examples = 10
-    sequence_length = 5
-
-    m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
-    m2 = deepcopy(m)
-    dataset = m.example_inputs(
-        dataset_size,
-        sequence_length=sequence_length,
-        dtype=original_dtype,
-        device=device,
-    )
-    calibration_data = dataset[:n_calibration_examples]
-
-    # calibrate
-    insert_awq_observer_(
-        m,
-        n_validation_examples,
-        sequence_length,
-        quant_dtype=quant_dtype,
-        group_size=group_size,
-    )
-
-    for example in calibration_data:
-        m(example.to(device))
-
-    # quantize
-    is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
-    quantize_(
-        m, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear
-    )
-
-    model_save_path = "awq_model.pth"
-    torch.save(m.state_dict(), model_save_path)
-    m2.load_state_dict(
-        torch.load(model_save_path), assign=True
-    )  # load weights only.torch.load(model_save_path)
-    os.remove(model_save_path)
-
-    m = torch.compile(m, fullgraph=True)
-    m2 = torch.compile(m2, fullgraph=True)
-
-    awq_out = torch.cat([m(i.squeeze(0)) for i in dataset])
-    awq_save_load_out = torch.cat([m2(i.squeeze(0)) for i in dataset])
-
-    assert awq_out is not None
-    assert awq_save_load_out is not None
-    assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
+        awq_out = torch.cat([m(d.squeeze(0)) for d in dataset])
+        awq_save_load_out = torch.cat([loaded_model(d.squeeze(0)) for d in dataset])
+
+        assert awq_out is not None
+        assert awq_save_load_out is not None
+        assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
+
+
+if __name__ == "__main__":
+    run_tests()
