
Commit 7c91a04

Make AWQ more general
Summary:
* Added AWQConfig that takes a base config and made corresponding changes in other parts of the flow

Test Plan: Tested on Phi4-mini and Qwen3-8B

Qwen3-8B

| Task                       | calibration_limit | no-awq | awq    |
|----------------------------|-------------------|--------|--------|
| leaderboard_math_hard (v3) | 2                 | 0.3543 | 0.4371 |
| gpqa_main_zeroshot         | 50                | 0.32   | 0.36   |
| mmlu                       | 5                 | 0.7372 | 0.7463 |
| bbh                        | 1                 | 0.7385 | 0.7556 |

Phi4-mini

| Task     | calibration_limit | no-awq | awq    |
|----------|-------------------|--------|--------|
| mmlu_pro | 2                 | 0.4057 | 0.4757 |
| gsm8k    | 5                 | 0.72   | 0.76   |

Reviewers:
Subscribers:
Tasks:
Tags:
1 parent ffaf572 commit 7c91a04
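For context: this commit replaces the old observer-insertion API (insert_awq_observer_ / awq_uintx) with a single AWQConfig that wraps any base workflow config and is applied in three steps, PREPARE (install observers), calibration (run sample data), and CONVERT (compute scales and quantize). A minimal sketch of that flow using the API as it appears in the diffs below; the toy model and random calibration data are placeholders, not part of the commit:

import torch

from torchao.prototype.awq import AWQConfig, AWQStep
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# placeholder model and calibration data, for illustration only (requires CUDA)
model = torch.nn.Sequential(torch.nn.Linear(128, 128, bias=False))
model = model.to(torch.bfloat16).to("cuda")
calibration_batches = [
    torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(10)
]

base_config = Int4WeightOnlyConfig(group_size=128)

# step 1: wrap the base config and insert activation observers
quantize_(model, AWQConfig(base_config, step=AWQStep.PREPARE))

# step 2: calibrate by running sample data through the model
for batch in calibration_batches:
    model(batch)

# step 3: compute AWQ scales from the observed statistics and apply the base config
quantize_(model, AWQConfig(base_config, step=AWQStep.CONVERT))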

13 files changed: 493 additions, 520 deletions


test/core/test_config.py

Lines changed: 6 additions & 0 deletions
@@ -19,6 +19,10 @@
     config_from_dict,
     config_to_dict,
 )
+from torchao.prototype.awq import (
+    AWQConfig,
+    AWQStep,
+)
 from torchao.quantization.quant_api import (
     FbgemmConfig,
     Float8DynamicActivationFloat8WeightConfig,
@@ -79,6 +83,8 @@
             "linear2": Int8DynamicActivationInt4WeightConfig(),
         }
     ),
+    AWQConfig(Int4WeightOnlyConfig(group_size=128), step=AWQStep.PREPARE_FOR_LOADING),
+    AWQConfig(Int4WeightOnlyConfig(group_size=128), step="prepare_for_loading"),
 ]
 
 if TORCH_VERSION_AT_LEAST_2_6:
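Including AWQConfig in this list means it must round-trip through torchao's config serialization helpers, which the surrounding test exercises for every entry. A minimal sketch of that round-trip, assuming config_to_dict and config_from_dict are importable as in this test file (the exact module path is an assumption here):

from torchao.core.config import config_from_dict, config_to_dict  # assumed import path

from torchao.prototype.awq import AWQConfig, AWQStep
from torchao.quantization import Int4WeightOnlyConfig

config = AWQConfig(Int4WeightOnlyConfig(group_size=128), step=AWQStep.PREPARE_FOR_LOADING)
# serialize to a plain dict (e.g. for checkpoint metadata), then reconstruct
reloaded = config_from_dict(config_to_dict(config))
assert isinstance(reloaded, AWQConfig)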

test/prototype/test_awq.py

Lines changed: 192 additions & 143 deletions
@@ -3,29 +3,26 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-import os
-from copy import deepcopy
+import copy
+import tempfile
+import unittest
 
-import pytest
 import torch
-
-from torchao.quantization import quantize_
-from torchao.testing.utils import skip_if_rocm
-from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_3,
-    TORCH_VERSION_AT_LEAST_2_5,
+from torch.testing._internal.common_utils import (
+    TestCase,
+    run_tests,
 )
 
-if TORCH_VERSION_AT_LEAST_2_3:
-    from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_
+from torchao.prototype.awq import AWQConfig, AWQStep
+from torchao.quantization import FbgemmConfig, Int4WeightOnlyConfig, quantize_
 
 
 class ToyLinearModel(torch.nn.Module):
     def __init__(self, m=512, n=256, k=128):
         super().__init__()
         self.linear1 = torch.nn.Linear(m, n, bias=False)
         self.linear2 = torch.nn.Linear(n, k, bias=False)
-        self.linear3 = torch.nn.Linear(k, 1, bias=False)
+        self.linear3 = torch.nn.Linear(k, 64, bias=False)
 
     def example_inputs(
         self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda"
@@ -44,137 +41,189 @@ def forward(self, x):
         return x
 
 
-devices = ["cpu", "cuda"]
-# torch.uintx dtypes are introduced in 2.3
-if TORCH_VERSION_AT_LEAST_2_3:
-    qdtypes = (torch.uint4, torch.uint7)
-else:
-    qdtypes = ()
-
-
-@pytest.fixture(autouse=True)
-def run_before_and_after_tests():
-    yield
-    torch._dynamo.reset()  # reset cache between tests
-
-
-@pytest.mark.parametrize("device", devices)
-@pytest.mark.parametrize("qdtype", qdtypes)
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
-@pytest.mark.skip("Temporarily skipping to unpin nightiles")
-def test_awq_loading(device, qdtype):
-    if qdtype == torch.uint4 and device == "cpu":
-        pytest.skip("uint4 not supported on cpu")
-
-    dataset_size = 100
-    l1, l2, l3 = 512, 256, 128
-    original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
-    quant_dtype = qdtype
-    group_size = 128
-    n_calibration_examples = 10
-    n_validation_examples = 10
-    sequence_length = 5
-
-    m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
-    dataset = m.example_inputs(
-        dataset_size,
-        sequence_length=sequence_length,
-        dtype=original_dtype,
-        device=device,
-    )
-    calibration_data = dataset[:n_calibration_examples]
-
-    # calibrate
-    insert_awq_observer_(
-        m,
-        n_validation_examples,
-        sequence_length,
-        quant_dtype=quant_dtype,
-        group_size=group_size,
-    )
-
-    for example in calibration_data:
-        m(example.to(device))
-
-    # quantize
-    is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
-    quantize_(
-        m, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear
-    )
-
-    model_save_path = "awq_model.pth"
-    torch.save(m, model_save_path)
-    loaded_model = torch.load(model_save_path)
-    os.remove(model_save_path)
-
-    if torch.cuda.is_available():
+@unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available")
+class TestAWQ(TestCase):
+    def test_awq_config(self):
+        base_config = Int4WeightOnlyConfig()
+        AWQConfig(base_config, step=AWQStep.PREPARE)
+        AWQConfig(base_config, step=AWQStep.PREPARE_FOR_LOADING)
+        AWQConfig(base_config, step=AWQStep.CONVERT)
+
+        AWQConfig(base_config, step="prepare")
+        AWQConfig(base_config, step="prepare_for_loading")
+        AWQConfig(base_config, step="convert")
+
+        with self.assertRaisesRegex(ValueError, "is not one of"):
+            AWQConfig(base_config, step="not_supported")
+
+    def test_awq_functionality(self):
+        device = "cuda"
+        dataset_size = 100
+        l1, l2, l3 = 512, 256, 128
+        original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
+        group_size = 128
+        n_calibration_examples = 10
+        sequence_length = 5
+
+        m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
+
+        # baseline quantization
+        base_config = FbgemmConfig(
+            input_dtype=torch.bfloat16,
+            weight_dtype=torch.int4,
+            output_dtype=torch.bfloat16,
+            block_size=[1, group_size],
+            preshuffle=False,
+        )
+        m_baseline = copy.deepcopy(m)
+        quantize_(m_baseline, base_config)
+
+        # awq quantization
+        dataset = m.example_inputs(
+            dataset_size,
+            sequence_length=sequence_length,
+            dtype=original_dtype,
+            device=device,
+        )
+        ref_out = torch.cat([m(d.squeeze(0)) for d in dataset])
+
+        calibration_data = dataset[:n_calibration_examples]
+
+        quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
+        quantize_(m, quant_config)
+
+        for example in calibration_data:
+            m(example)
+
+        quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
+        quantize_(m, quant_config)
+
+        awq_out = torch.cat([m(d.squeeze(0)) for d in dataset])
+        baseline_out = torch.cat([m_baseline(d.squeeze(0)) for d in dataset])
+
+        loss_awq = (ref_out - awq_out).pow(2).mean().item()
+        loss_base = (ref_out - baseline_out).pow(2).mean().item()
+        assert loss_awq < loss_base
+
+    def test_awq_loading(self):
+        device = "cuda"
+        dataset_size = 100
+        l1, l2, l3 = 512, 256, 128
+        original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
+        group_size = 128
+        n_calibration_examples = 10
+        sequence_length = 5
+
+        m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
+        dataset = m.example_inputs(
+            dataset_size,
+            sequence_length=sequence_length,
+            dtype=original_dtype,
+            device=device,
+        )
+        calibration_data = dataset[:n_calibration_examples]
+
+        # calibrate
+        base_config = FbgemmConfig(
+            input_dtype=torch.bfloat16,
+            weight_dtype=torch.int4,
+            output_dtype=torch.bfloat16,
+            block_size=[1, group_size],
+            preshuffle=False,
+        )
+        quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
+        quantize_(m, quant_config)
+
+        for example in calibration_data:
+            m(example)
+
+        # quantize
+        quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
+        quantize_(m, quant_config)
+
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(m.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+
+        loaded_model = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
+        loaded_model.load_state_dict(state_dict, assign=True)
+
         m = torch.compile(m, fullgraph=True)
         loaded_model = torch.compile(loaded_model, fullgraph=True)
 
-    awq_out = torch.cat([m(i.squeeze(0)) for i in dataset])
-    awq_save_load_out = torch.cat([loaded_model(i.squeeze(0)) for i in dataset])
-
-    assert awq_out is not None
-    assert awq_save_load_out is not None
-    assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
-
-
-@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@skip_if_rocm("ROCm enablement in progress")
-def test_save_weights_only():
-    dataset_size = 100
-    l1, l2, l3 = 512, 256, 128
-    original_dtype = torch.bfloat16
-    quant_dtype = torch.uint4
-    device = "cuda"
-    group_size = 128
-    n_calibration_examples = 10
-    n_validation_examples = 10
-    sequence_length = 5
-
-    m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
-    m2 = deepcopy(m)
-    dataset = m.example_inputs(
-        dataset_size,
-        sequence_length=sequence_length,
-        dtype=original_dtype,
-        device=device,
-    )
-    calibration_data = dataset[:n_calibration_examples]
-
-    # calibrate
-    insert_awq_observer_(
-        m,
-        n_validation_examples,
-        sequence_length,
-        quant_dtype=quant_dtype,
-        group_size=group_size,
-    )
-
-    for example in calibration_data:
-        m(example.to(device))
-
-    # quantize
-    is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
-    quantize_(
-        m, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear
-    )
-
-    model_save_path = "awq_model.pth"
-    torch.save(m.state_dict(), model_save_path)
-    m2.load_state_dict(
-        torch.load(model_save_path), assign=True
-    )  # load weights only.torch.load(model_save_path)
-    os.remove(model_save_path)
-
-    m = torch.compile(m, fullgraph=True)
-    m2 = torch.compile(m2, fullgraph=True)
-
-    awq_out = torch.cat([m(i.squeeze(0)) for i in dataset])
-    awq_save_load_out = torch.cat([m2(i.squeeze(0)) for i in dataset])
-
-    assert awq_out is not None
-    assert awq_save_load_out is not None
-    assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
+        awq_out = torch.cat([m(d.squeeze(0)) for d in dataset])
+        awq_save_load_out = torch.cat([loaded_model(d.squeeze(0)) for d in dataset])
+
+        assert awq_out is not None
+        assert awq_save_load_out is not None
+        assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
+
+    def test_awq_loading_vllm(self):
+        """Simulate weight loading in vllm:
+        * prepare model weights in the same format (awq weight)
+        * use weight.copy_(state_dict["weight"]) to copy over the quantized weights from the checkpoint
+
+        There is also a slicing op that is omitted here; the overall e2e flow is tested in the vllm repo
+        """
+        device = "cuda"
+        dataset_size = 100
+        l1, l2, l3 = 512, 256, 128
+        original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
+        group_size = 128
+        n_calibration_examples = 10
+        sequence_length = 5
+
+        m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
+        dataset = m.example_inputs(
+            dataset_size,
+            sequence_length=sequence_length,
+            dtype=original_dtype,
+            device=device,
+        )
+        calibration_data = dataset[:n_calibration_examples]
+
+        # calibrate
+        base_config = FbgemmConfig(
+            input_dtype=torch.bfloat16,
+            weight_dtype=torch.int4,
+            output_dtype=torch.bfloat16,
+            block_size=[1, group_size],
+            preshuffle=False,
+        )
+        quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
+        quantize_(m, quant_config)
+
+        for example in calibration_data:
+            m(example)
+
+        # quantize
+        quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
+        quantize_(m, quant_config)
+
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(m.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+
+        loaded_model = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
+        quant_config = AWQConfig(base_config, step=AWQStep.PREPARE_FOR_LOADING)
+        quantize_(loaded_model, quant_config)
+
+        loaded_model.linear1.weight.copy_(state_dict["linear1.weight"])
+        loaded_model.linear2.weight.copy_(state_dict["linear2.weight"])
+        loaded_model.linear3.weight.copy_(state_dict["linear3.weight"])
+
+        m = torch.compile(m, fullgraph=True)
+        loaded_model = torch.compile(loaded_model, fullgraph=True)
+
+        awq_out = torch.cat([m(d.squeeze(0)) for d in dataset])
+        awq_save_load_out = torch.cat([loaded_model(d.squeeze(0)) for d in dataset])
+
+        assert awq_out is not None
+        assert awq_save_load_out is not None
+        assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
+
+
+if __name__ == "__main__":
+    run_tests()
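The PREPARE_FOR_LOADING step exists so that a serving engine such as vllm can materialize weights in the final quantized format first and then copy checkpoint tensors in place, as test_awq_loading_vllm simulates above. A condensed sketch of just that loading path, assuming ToyLinearModel, base_config, and an AWQ-quantized state_dict produced by the PREPARE/CONVERT flow shown in the tests:

# materialize quantized-format weights before any checkpoint data arrives
loaded_model = ToyLinearModel(512, 256, 128).eval().to(torch.bfloat16).to("cuda")
quantize_(loaded_model, AWQConfig(base_config, step=AWQStep.PREPARE_FOR_LOADING))

# the weights now match the checkpoint format, so an in-place copy_ succeeds
for name in ("linear1", "linear2", "linear3"):
    getattr(loaded_model, name).weight.copy_(state_dict[f"{name}.weight"])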
