Skip to content

Commit 7bc824a

Browse files
committed
Add GPTQQuantizer
Summary: Implement GPTQQuantizer with the unified quantizer API.

Test Plan: python test/quantization/test_quant_api.py

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
1 parent 55e5d40 commit 7bc824a

File tree

3 files changed

+991
-3
lines changed

3 files changed

+991
-3
lines changed

test/quantization/test_quant_api.py

Lines changed: 29 additions & 1 deletion

@@ -22,7 +22,11 @@
 from torchao.quantization.quant_api import (
     Quantizer,
     TwoStepQuantizer,
+    Int8DynActInt4WeightGPTQQuantizer,
 )
+from pathlib import Path
+from sentencepiece import SentencePieceProcessor
+

 def dynamic_quant(model, example_inputs):
     m = capture_pre_autograd_graph(model, example_inputs)
@@ -127,7 +131,31 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):

     def test_gptq(self):
         # should be similar to TorchCompileDynamicQuantizer
-        pass
+        m = M().eval()
+        checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
+        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
+        assert tokenizer_path.is_file(), tokenizer_path
+        tokenizer = SentencePieceProcessor(  # pyre-ignore[28]
+            model_file=str(tokenizer_path)
+        )
+        blocksize = 128
+        percdamp = 0.01
+        groupsize = 128
+        calibration_tasks = ["hellaswag"]
+        calibration_limit = 1000
+        calibration_seq_length = 100
+        pad_calibration_inputs = False
+        quantizer = Int8DynActInt4WeightGPTQQuantizer(
+            tokenizer,
+            blocksize,
+            percdamp,
+            groupsize,
+            calibration_tasks,
+            calibration_limit,
+            calibration_seq_length,
+            pad_calibration_inputs,
+        )
+        m = quantizer.quantize(m)

 if __name__ == "__main__":
     unittest.main()

0 commit comments

Comments
 (0)