Commit 60f8852

Make AWQ more general

Summary:
* Added AWQConfig that takes a base config and made corresponding changes in other parts of the flow

Test Plan: Tested on Phi4-mini and Qwen3-8B

Qwen3-8B

| Task                       | calibration_limit | no-awq | awq    |
|----------------------------|-------------------|--------|--------|
| leaderboard_math_hard (v3) | 2                 | 0.3543 | 0.4371 |
| gpqa_main_zeroshot         | 50                | 0.32   | 0.36   |
| mmlu                       | 5                 | 0.7372 | 0.7463 |
| bbh                        | 1                 | 0.7385 | 0.7556 |

Phi4-mini

| Task     | calibration_limit | no-awq | awq    |
|----------|-------------------|--------|--------|
| mmlu_pro | 2                 | 0.4057 | 0.4757 |
| gsm8k    | 5                 | 0.72   | 0.76   |

Reviewers:

Subscribers:

Tasks:

Tags:
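For orientation, a minimal sketch of how a base-config-wrapping `AWQConfig` could be used with `quantize_`. The `step` keyword, the `Int4WeightOnlyConfig` base config, and the `model`/`calibration_batches` placeholders are assumptions for illustration, not part of this diff:

```python
# Hedged sketch of the config-based AWQ flow described in the summary.
# Assumptions: AWQConfig takes the base config positionally plus a `step`
# keyword; `model` and `calibration_batches` are placeholders.
import torch
from torchao.prototype.awq import AWQConfig
from torchao.quantization import Int4WeightOnlyConfig, quantize_

base_config = Int4WeightOnlyConfig(group_size=128)

# prepare: swap eligible linears for observed variants that record
# the activation statistics AWQ needs to compute its scales
quantize_(model, AWQConfig(base_config, step="prepare"))

# calibrate: run a few batches so the observers see real activations
for batch in calibration_batches:
    model(batch)

# convert: fold the AWQ scales in and quantize with the wrapped base config
quantize_(model, AWQConfig(base_config, step="convert"))
```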
1 parent 2e2ce0b commit 60f8852

10 files changed: +470 −308 lines changed

torchao/_models/_eval.py

Lines changed: 15 additions & 6 deletions
@@ -57,8 +57,13 @@ def _model_call(self, inps):
 
         max_seq_length = min(max(inps.size()), self.max_length)
         with torch.device(self._device):
-            self._model.setup_caches(self.batch_size, max_seq_length)
+            if hasattr(self._model, "setup_caches"):
+                self._model.setup_caches(self.batch_size, max_seq_length)
         logits = self._model(*input)
+        from transformers.modeling_outputs import CausalLMOutputWithPast
+
+        if isinstance(logits, CausalLMOutputWithPast):
+            logits = logits.logits
         return logits
 
     def run_eval(self, tasks, limit):
@@ -84,7 +89,11 @@ def eot_token_id(self):
         try:
             return self.tokenizer.eos_id()
         except:
-            return self.tokenizer.eos_id
+            try:
+                return self.tokenizer.eos_id
+            except:
+                idx = self.tokenizer.all_special_tokens.index("<|endoftext|>")
+                return self.tokenizer.all_special_ids[idx]
 
     @property
     def max_length(self):
@@ -102,8 +111,8 @@ def batch_size(self):
     def device(self):
         return self._device
 
-    def tok_decode(self, tokens):
-        decoded = self.tokenizer.decode(tokens)
+    def tok_decode(self, tokens, **kwargs):
+        decoded = self.tokenizer.decode(tokens, **kwargs)
         return decoded
 
     def tok_encode(self, string: str, **kwargs):
@@ -115,8 +124,8 @@ def tok_encode(self, string: str, **kwargs):
             tokens = [self.tokenizer.bos_id] + tokens
         return tokens
 
-    def _model_generate(self, context, max_length, eos_token_id):
-        raise Exception("unimplemented")
+    # def _model_generate(self, context, max_length, stop, **generation_kwargs):
+    #     super()._model_generate(context, max_length, stop, **generation_kwargs)
 
 
 class LMEvalInputRecorder(TransformerEvalWrapper):
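Context for the `_model_call` change above: Hugging Face causal LMs return a structured `ModelOutput` rather than a raw logits tensor, which is why the wrapper now unwraps it. A small illustration (the checkpoint name is just an example, chosen from the models in the test plan):

```python
# HF causal LM forwards return a ModelOutput object, not a tensor.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-4-mini-instruct")
out = model(torch.tensor([[1, 2, 3]]))
print(type(out).__name__)  # CausalLMOutputWithPast
logits = out.logits        # the tensor the eval wrapper now extracts
```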

torchao/_models/llama/eval.py

Lines changed: 81 additions & 0 deletions
@@ -237,6 +237,87 @@ def run_evaluation(
         quantize_(
             model, codebook_weight_only(dtype=torch.uint4, scale_block_size=64)
         )
+    elif quantization.startswith("awq-uintx"):
+        from torchao._models._eval import TransformerEvalWrapper
+        from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
+
+        if not TORCH_VERSION_AT_LEAST_2_3:
+            print("Awq requires torch2.3+")
+            exit()
+        from torchao.prototype.awq import (
+            AWQObservedLinear,
+            awq_uintx,
+            insert_awq_observer_,
+        )
+
+        quant_dtype = quantization.split("-")[1]
+        group_size = int(quantization.split("-")[2])
+        quant_dtype = getattr(torch, quant_dtype, torch.uint8)
+        model = model.to(device)
+        # get calibration data
+        insert_awq_observer_(
+            model, 1, 256, quant_dtype=quant_dtype, group_size=group_size
+        )
+        TransformerEvalWrapper(
+            model=model.to(device),
+            tokenizer=tokenizer,
+            max_seq_length=256,
+            input_prep_func=prepare_inputs_for_model,
+            device=device,
+        ).run_eval(
+            tasks=["wikitext"],
+            limit=1,
+        )
+        is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
+        use_hqq = "hqq" in quantization
+        quantize_(
+            model,
+            awq_uintx(
+                quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq
+            ),
+            is_observed_linear,
+        )
+
+    elif quantization.startswith("awq-8da4w"):
+        from torchao._models._eval import TransformerEvalWrapper
+        from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
+
+        if not TORCH_VERSION_AT_LEAST_2_3:
+            print("Awq requires torch2.3+")
+            exit()
+        from torchao.prototype.awq import (
+            AWQObservedLinear,
+            awq_uintx,
+            insert_awq_observer_,
+        )
+
+        quant_dtype = quantization.split("-")[1]
+        group_size = int(quantization.split("-")[2])
+        quant_dtype = getattr(torch, quant_dtype, torch.uint8)
+        model = model.to(device)
+        # get calibration data
+        insert_awq_observer_(
+            model, 1, 256, quant_dtype=quant_dtype, group_size=group_size
+        )
+        TransformerEvalWrapper(
+            model=model.to(device),
+            tokenizer=tokenizer,
+            max_seq_length=256,
+            input_prep_func=prepare_inputs_for_model,
+            device=device,
+        ).run_eval(
+            tasks=["wikitext"],
+            limit=1,
+        )
+        is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
+        use_hqq = "hqq" in quantization
+        quantize_(
+            model,
+            awq_uintx(
+                quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq
+            ),
+            is_observed_linear,
+        )
 
     if compile:
         model = torch.compile(model, mode="max-autotune", fullgraph=True)
torchao/prototype/awq/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,8 +1,9 @@
-from .api import awq_uintx, insert_awq_observer_
+from .api import AWQConfig, awq_uintx, insert_awq_observer_
 from .core import AWQObservedLinear
 
 __all__ = [
     "awq_uintx",
     "insert_awq_observer_",
     "AWQObservedLinear",
+    "AWQConfig",
 ]
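With the export in place, the new config is importable alongside the existing prototype entry points:

```python
# Everything re-exported by torchao/prototype/awq/__init__.py after this change.
from torchao.prototype.awq import (
    AWQConfig,
    AWQObservedLinear,
    awq_uintx,
    insert_awq_observer_,
)
```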
