@@ -22,7 +22,11 @@
 from torchao.quantization.quant_api import (
     Quantizer,
     TwoStepQuantizer,
+    Int8DynActInt4WeightGPTQQuantizer,
 )
+from pathlib import Path
+from sentencepiece import SentencePieceProcessor
+

 def dynamic_quant(model, example_inputs):
     m = capture_pre_autograd_graph(model, example_inputs)
@@ -127,7 +131,31 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):

     def test_gptq(self):
         # should be similar to TorchCompileDynamicQuantizer
-        pass
+        m = M().eval()
+        checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
+        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
+        assert tokenizer_path.is_file(), tokenizer_path
+        tokenizer = SentencePieceProcessor(  # pyre-ignore[28]
+            model_file=str(tokenizer_path)
+        )
+        blocksize = 128
+        percdamp = 0.01
+        groupsize = 128
+        calibration_tasks = ["hellaswag"]
+        calibration_limit = 1000
+        calibration_seq_length = 100
+        pad_calibration_inputs = False
+        quantizer = Int8DynActInt4WeightGPTQQuantizer(
+            tokenizer,
+            blocksize,
+            percdamp,
+            groupsize,
+            calibration_tasks,
+            calibration_limit,
+            calibration_seq_length,
+            pad_calibration_inputs,
+        )
+        m = quantizer.quantize(m)

 if __name__ == "__main__":
     unittest.main()
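
For context: the new test exercises torchao's GPTQ flow end to end. It loads a SentencePiece tokenizer from alongside a Llama-2 checkpoint, constructs an Int8DynActInt4WeightGPTQQuantizer (int8 dynamically quantized activations, int4 GPTQ-quantized weights), and calls quantize() on an eval-mode model. Below is a minimal standalone sketch of the same flow, not a definitive recipe: the checkpoint path, the model loader, and the output filename are placeholders, the constructor arguments are passed positionally exactly as in the test, and the parameter comments reflect the usual GPTQ meanings of those knobs.

# Sketch: quantize a model with torchao's Int8DynActInt4WeightGPTQQuantizer.
# Paths and the model loader are placeholders; GPTQ calibration runs real
# eval tasks over the model, so expect this to be slow on a 7B checkpoint.
from pathlib import Path

import torch
from sentencepiece import SentencePieceProcessor
from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer

checkpoint_dir = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf")  # placeholder
tokenizer = SentencePieceProcessor(model_file=str(checkpoint_dir / "tokenizer.model"))

model = load_model(checkpoint_dir / "model.pth").eval()  # hypothetical loader

quantizer = Int8DynActInt4WeightGPTQQuantizer(
    tokenizer,       # tokenizes the calibration text
    128,             # blocksize: weight columns handled per GPTQ update block
    0.01,            # percdamp: damping added to the Hessian diagonal
    128,             # groupsize: group size for the int4 weight quantization
    ["hellaswag"],   # calibration_tasks: tasks that supply calibration inputs
    1000,            # calibration_limit: max number of calibration samples
    100,             # calibration_seq_length: tokens per calibration sample
    False,           # pad_calibration_inputs: don't pad short inputs
)
model = quantizer.quantize(model)  # calibrate + GPTQ-quantize, returns the model
torch.save(model.state_dict(), "model_int8da_int4w_gptq.pth")  # placeholder output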