
Commit 8afdb1f

Make AWQ more general
Summary:

* Added AWQConfig that takes a base config and made corresponding changes in other parts of the flow

Test Plan: Tested on Phi4-mini and Qwen3-8B

Qwen3-8B

| Task                       | calibration_limit | no-awq | awq    |
|----------------------------|-------------------|--------|--------|
| leaderboard_math_hard (v3) | 2                 | 0.3543 | 0.4371 |
| gpqa_main_zeroshot         | 50                | 0.32   | 0.36   |
| mmlu                       | 5                 | 0.7372 | 0.7463 |
| bbh                        | 1                 | 0.7385 | 0.7556 |

Phi4-mini

| Task     | calibration_limit | no-awq | awq    |
|----------|-------------------|--------|--------|
| mmlu_pro | 2                 | 0.4057 | 0.4757 |
| gsm8k    | 5                 | 0.72   | 0.76   |

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent b6ef500 commit 8afdb1f
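
A minimal sketch of the three-step flow the new AWQConfig enables, pieced together from the test added in test/prototype/test_awq.py below. The toy model and random calibration data are illustrative stand-ins, and a CUDA device plus the fbgemm int4 kernels are assumed:

import torch
from torchao.prototype.awq import AWQConfig, AWQStep
from torchao.quantization import FbgemmConfig, quantize_

# toy model and calibration data, mirroring the shapes used in the new test
model = (
    torch.nn.Sequential(
        torch.nn.Linear(512, 256, bias=False),
        torch.nn.Linear(256, 128, bias=False),
    )
    .eval()
    .to(torch.bfloat16)
    .to("cuda")
)
calibration_data = [
    torch.randn(5, 512, dtype=torch.bfloat16, device="cuda") for _ in range(10)
]

# int4 weight-only base config, as used in the test
base_config = FbgemmConfig(
    input_dtype=torch.bfloat16,
    weight_dtype=torch.int4,
    output_dtype=torch.bfloat16,
    block_size=[1, 128],
    preshuffle=False,
)

# step 1: insert AWQ observers
quantize_(model, AWQConfig(base_config, step=AWQStep.PREPARE))

# step 2: calibrate by running example inputs through the model
for example in calibration_data:
    model(example)

# step 3: compute the AWQ scales and apply the base quantization config
quantize_(model, AWQConfig(base_config, step=AWQStep.CONVERT))

For loading a pre-quantized checkpoint (the vLLM path exercised by test_awq_loading_vllm), the empty model is instead prepared with AWQStep.PREPARE_FOR_LOADING and the quantized weights are copied in from the state dict.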

File tree

13 files changed: +407 −448 lines changed


test/prototype/test_awq.py

Lines changed: 117 additions & 70 deletions
@@ -3,29 +3,28 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-import os
-from copy import deepcopy
+import copy
+import tempfile
 
 import pytest
 import torch
 
-from torchao.quantization import quantize_
-from torchao.testing.utils import skip_if_rocm
+from torchao.quantization import FbgemmConfig, quantize_
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_5,
 )
 
 if TORCH_VERSION_AT_LEAST_2_3:
-    from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_
+    from torchao.prototype.awq import AWQConfig, AWQStep
 
 
 class ToyLinearModel(torch.nn.Module):
     def __init__(self, m=512, n=256, k=128):
         super().__init__()
         self.linear1 = torch.nn.Linear(m, n, bias=False)
         self.linear2 = torch.nn.Linear(n, k, bias=False)
-        self.linear3 = torch.nn.Linear(k, 1, bias=False)
+        self.linear3 = torch.nn.Linear(k, 64, bias=False)
 
     def example_inputs(
         self, batch_size, sequence_length=10, dtype=torch.bfloat16, device="cuda"
@@ -44,36 +43,74 @@ def forward(self, x):
         return x
 
 
-devices = ["cpu", "cuda"]
-# torch.uintx dtypes are introduced in 2.3
-if TORCH_VERSION_AT_LEAST_2_3:
-    qdtypes = (torch.uint4, torch.uint7)
-else:
-    qdtypes = ()
-
-
 @pytest.fixture(autouse=True)
 def run_before_and_after_tests():
     yield
     torch._dynamo.reset()  # reset cache between tests
 
 
-@pytest.mark.parametrize("device", devices)
-@pytest.mark.parametrize("qdtype", qdtypes)
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
-@pytest.mark.skip("Temporarily skipping to unpin nightiles")
-def test_awq_loading(device, qdtype):
-    if qdtype == torch.uint4 and device == "cpu":
-        pytest.skip("uint4 not supported on cpu")
+def test_awq_functionality():
+    device = "cuda"
+    dataset_size = 100
+    l1, l2, l3 = 512, 256, 128
+    original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
+    group_size = 128
+    n_calibration_examples = 10
+    sequence_length = 5
+
+    m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
+
+    # baseline quantization
+    base_config = FbgemmConfig(
+        input_dtype=torch.bfloat16,
+        weight_dtype=torch.int4,
+        output_dtype=torch.bfloat16,
+        block_size=[1, group_size],
+        preshuffle=False,
+    )
+    m_baseline = copy.deepcopy(m)
+    quantize_(m_baseline, base_config)
+
+    # awq quantization
+    dataset = m.example_inputs(
+        dataset_size,
+        sequence_length=sequence_length,
+        dtype=original_dtype,
+        device=device,
+    )
+    ref_out = torch.cat([m(d.squeeze(0)) for d in dataset])
+
+    calibration_data = dataset[:n_calibration_examples]
 
+    quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
+    quantize_(m, quant_config)
+
+    for example in calibration_data:
+        print("device:", example.device)
+        m(example)
+
+    quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
+    quantize_(m, quant_config)
+
+    awq_out = torch.cat([m(d.squeeze(0)) for d in dataset])
+    baseline_out = torch.cat([m_baseline(d.squeeze(0)) for d in dataset])
+
+    loss_awq = (ref_out - awq_out).pow(2).mean().item()
+    loss_base = (ref_out - baseline_out).pow(2).mean().item()
+    assert loss_awq < loss_base
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
+def test_awq_loading():
+    device = "cuda"
     dataset_size = 100
     l1, l2, l3 = 512, 256, 128
     original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
-    quant_dtype = qdtype
     group_size = 128
     n_calibration_examples = 10
-    n_validation_examples = 10
     sequence_length = 5
 
     m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
@@ -86,56 +123,60 @@ def test_awq_loading(device, qdtype):
     calibration_data = dataset[:n_calibration_examples]
 
     # calibrate
-    insert_awq_observer_(
-        m,
-        n_validation_examples,
-        sequence_length,
-        quant_dtype=quant_dtype,
-        group_size=group_size,
+    base_config = FbgemmConfig(
+        input_dtype=torch.bfloat16,
+        weight_dtype=torch.int4,
+        output_dtype=torch.bfloat16,
+        block_size=[1, group_size],
+        preshuffle=False,
     )
+    quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
+    quantize_(m, quant_config)
 
     for example in calibration_data:
-        m(example.to(device))
+        m(example)
 
     # quantize
-    is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
-    quantize_(
-        m, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear
-    )
+    quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
+    quantize_(m, quant_config)
 
-    model_save_path = "awq_model.pth"
-    torch.save(m, model_save_path)
-    loaded_model = torch.load(model_save_path)
-    os.remove(model_save_path)
+    with tempfile.NamedTemporaryFile() as f:
+        torch.save(m.state_dict(), f)
+        f.seek(0)
+        state_dict = torch.load(f)
 
-    if torch.cuda.is_available():
-        m = torch.compile(m, fullgraph=True)
-        loaded_model = torch.compile(loaded_model, fullgraph=True)
+    loaded_model = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
+    loaded_model.load_state_dict(state_dict, assign=True)
 
-    awq_out = torch.cat([m(i.squeeze(0)) for i in dataset])
-    awq_save_load_out = torch.cat([loaded_model(i.squeeze(0)) for i in dataset])
+    m = torch.compile(m, fullgraph=True)
+    loaded_model = torch.compile(loaded_model, fullgraph=True)
+
+    awq_out = torch.cat([m(d.squeeze(0)) for d in dataset])
+    awq_save_load_out = torch.cat([loaded_model(d.squeeze(0)) for d in dataset])
 
     assert awq_out is not None
     assert awq_save_load_out is not None
     assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
 
 
-@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@skip_if_rocm("ROCm enablement in progress")
-def test_save_weights_only():
+@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
+def test_awq_loading_vllm():
+    """Simulate weight loading in vllm:
+    * prepare model weight to the same format (awq weight)
+    * use weight.copy_(state_dict["weight"]) to copy over the quantized weights from checkpoint
+
+    There is also a slicing op that is omitted here, overall e2e is tested in tests in vllm repo
+    """
+    device = "cuda"
     dataset_size = 100
     l1, l2, l3 = 512, 256, 128
-    original_dtype = torch.bfloat16
-    quant_dtype = torch.uint4
-    device = "cuda"
+    original_dtype = torch.bfloat16  # tinygemm kernel only uses bfloat16 inputs
    group_size = 128
     n_calibration_examples = 10
-    n_validation_examples = 10
     sequence_length = 5
 
     m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
-    m2 = deepcopy(m)
     dataset = m.example_inputs(
         dataset_size,
         sequence_length=sequence_length,
@@ -145,35 +186,41 @@ def test_save_weights_only():
     calibration_data = dataset[:n_calibration_examples]
 
     # calibrate
-    insert_awq_observer_(
-        m,
-        n_validation_examples,
-        sequence_length,
-        quant_dtype=quant_dtype,
-        group_size=group_size,
+    base_config = FbgemmConfig(
+        input_dtype=torch.bfloat16,
+        weight_dtype=torch.int4,
+        output_dtype=torch.bfloat16,
+        block_size=[1, group_size],
+        preshuffle=False,
    )
+    quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
+    quantize_(m, quant_config)
 
     for example in calibration_data:
-        m(example.to(device))
+        m(example)
 
     # quantize
-    is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
-    quantize_(
-        m, awq_uintx(quant_dtype=quant_dtype, group_size=group_size), is_observed_linear
-    )
+    quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
+    quantize_(m, quant_config)
+
+    with tempfile.NamedTemporaryFile() as f:
+        torch.save(m.state_dict(), f)
+        f.seek(0)
+        state_dict = torch.load(f)
+
+    loaded_model = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
+    quant_config = AWQConfig(base_config, step=AWQStep.PREPARE_FOR_LOADING)
+    quantize_(loaded_model, quant_config)
 
-    model_save_path = "awq_model.pth"
-    torch.save(m.state_dict(), model_save_path)
-    m2.load_state_dict(
-        torch.load(model_save_path), assign=True
-    )  # load weights only.torch.load(model_save_path)
-    os.remove(model_save_path)
+    loaded_model.linear1.weight.copy_(state_dict["linear1.weight"])
+    loaded_model.linear2.weight.copy_(state_dict["linear2.weight"])
+    loaded_model.linear3.weight.copy_(state_dict["linear3.weight"])
 
     m = torch.compile(m, fullgraph=True)
-    m2 = torch.compile(m2, fullgraph=True)
+    loaded_model = torch.compile(loaded_model, fullgraph=True)
 
-    awq_out = torch.cat([m(i.squeeze(0)) for i in dataset])
-    awq_save_load_out = torch.cat([m2(i.squeeze(0)) for i in dataset])
+    awq_out = torch.cat([m(d.squeeze(0)) for d in dataset])
+    awq_save_load_out = torch.cat([loaded_model(d.squeeze(0)) for d in dataset])
 
     assert awq_out is not None
     assert awq_save_load_out is not None

test/quantization/test_config_serialization.py

Lines changed: 5 additions & 0 deletions
@@ -19,6 +19,10 @@
     config_from_dict,
     config_to_dict,
 )
+from torchao.prototype.awq import (
+    AWQConfig,
+    AWQStep,
+)
 from torchao.quantization.quant_api import (
     FbgemmConfig,
     Float8DynamicActivationFloat8WeightConfig,
@@ -79,6 +83,7 @@
             "linear2": Int8DynamicActivationInt4WeightConfig(),
         }
     ),
+    AWQConfig(Int4WeightOnlyConfig(group_size=128), step=AWQStep.PREPARE_FOR_LOAD),
 ]
 
 if TORCH_VERSION_AT_LEAST_2_6:

torchao/_models/_eval.py

Lines changed: 13 additions & 7 deletions
@@ -57,8 +57,13 @@ def _model_call(self, inps):
 
         max_seq_length = min(max(inps.size()), self.max_length)
         with torch.device(self._device):
-            self._model.setup_caches(self.batch_size, max_seq_length)
+            if hasattr(self._model, "setup_caches"):
+                self._model.setup_caches(self.batch_size, max_seq_length)
         logits = self._model(*input)
+        from transformers.modeling_outputs import CausalLMOutputWithPast
+
+        if isinstance(logits, CausalLMOutputWithPast):
+            logits = logits.logits
         return logits
 
     def run_eval(self, tasks, limit):
@@ -84,7 +89,11 @@ def eot_token_id(self):
         try:
             return self.tokenizer.eos_id()
         except:
-            return self.tokenizer.eos_id
+            try:
+                return self.tokenizer.eos_id
+            except:
+                idx = self.tokenizer.all_special_tokens.index("<|endoftext|>")
+                return self.tokenizer.all_special_ids[idx]
 
     @property
     def max_length(self):
@@ -102,8 +111,8 @@ def batch_size(self):
     def device(self):
         return self._device
 
-    def tok_decode(self, tokens):
-        decoded = self.tokenizer.decode(tokens)
+    def tok_decode(self, tokens, **kwargs):
+        decoded = self.tokenizer.decode(tokens, **kwargs)
         return decoded
 
     def tok_encode(self, string: str, **kwargs):
@@ -115,9 +124,6 @@ def tok_encode(self, string: str, **kwargs):
             tokens = [self.tokenizer.bos_id] + tokens
         return tokens
 
-    def _model_generate(self, context, max_length, eos_token_id):
-        raise Exception("unimplemented")
-
 
 class LMEvalInputRecorder(TransformerEvalWrapper):
     def __init__(

torchao/_models/llama/eval.py

Lines changed: 40 additions & 0 deletions
@@ -237,6 +237,46 @@ def run_evaluation(
             quantize_(
                 model, codebook_weight_only(dtype=torch.uint4, scale_block_size=64)
             )
+        elif quantization.startswith("awq-uintx"):
+            from torchao._models._eval import TransformerEvalWrapper
+            from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
+
+            if not TORCH_VERSION_AT_LEAST_2_3:
+                print("Awq requires torch2.3+")
+                exit()
+            from torchao.prototype.awq import (
+                AWQObservedLinear,
+                awq_uintx,
+                insert_awq_observer_,
+            )
+
+            quant_dtype = quantization.split("-")[1]
+            group_size = int(quantization.split("-")[2])
+            quant_dtype = getattr(torch, quant_dtype, torch.uint8)
+            model = model.to(device)
+            # get calibration data
+            insert_awq_observer_(
+                model, 1, 256, quant_dtype=quant_dtype, group_size=group_size
+            )
+            TransformerEvalWrapper(
+                model=model.to(device),
+                tokenizer=tokenizer,
+                max_seq_length=256,
+                input_prep_func=prepare_inputs_for_model,
+                device=device,
+            ).run_eval(
+                tasks=["wikitext"],
+                limit=1,
+            )
+            is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
+            use_hqq = "hqq" in quantization
+            quantize_(
+                model,
+                awq_uintx(
+                    quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq
+                ),
+                is_observed_linear,
+            )
 
     if compile:
         model = torch.compile(model, mode="max-autotune", fullgraph=True)

torchao/core/config.py

Lines changed: 1 addition & 0 deletions
@@ -191,6 +191,7 @@ def config_to_dict(config: AOBaseConfig) -> Dict[str, Any]:
     "torchao.prototype.quantization",
     "torchao.prototype.mx_formats",
     "torchao.dtypes",
+    "torchao.prototype.awq",
 }
 
