@@ -8,16 +8,14 @@
 import re
 from functools import partial
 from pathlib import Path
-from typing import Any, Dict, Optional
+from typing import Dict, Optional
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from executorch.extension.llm.export.builder import DType
 
-from sentencepiece import SentencePieceProcessor
-
 
 try:
     from fairseq2.nn.embedding import (
@@ -57,7 +55,7 @@ def quantize(  # noqa C901
 
     Args:
         model: The model to quantize.
-        qmode: The quantization mode, e.g. int8, 8da4w, 8da4w-gptq.
+        qmode: The quantization mode, e.g. int8, 8da4w.
         computation_dtype: The dtype that ops are performed in (the resulting dtype of dequantization).
             Also the dtype of the rest of the non-quantized components of the model.
         checkpoint_dtype: The dtype of the checkpoint, this arg exists since it is more accurate to
@@ -161,58 +159,6 @@ def quantize(  # noqa C901
         if verbose:
             print("quantized model:", model)
         return model
-    elif qmode == "8da4w-gptq":
-        # Check for required args
-        required_args: Optional[Any] = [
-            group_size,
-            calibration_limit,
-            calibration_seq_length,
-        ]
-        if any(arg is None for arg in required_args):
-            raise Exception(
-                "For 8da4w-gptq quantization, group size, calibration limit and calibration sequence length must be specified."
-            )
-        if calibration_tasks is None:
-            calibration_tasks = ["wikitext"]
-
-        try:
-            # torchao 0.3+
-            from torchao._models._eval import InputRecorder
-        except ImportError:
-            from torchao.quantization.GPTQ import InputRecorder  # pyre-ignore
-
-        from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer
-
-        if tokenizer_path is None:
-            assert checkpoint_path is not None, "checkpoint_path must be specified"
-            tokenizer_path = checkpoint_path.parent / "tokenizer.model"
-        assert tokenizer_path.is_file(), tokenizer_path
-        tokenizer = SentencePieceProcessor(  # pyre-ignore[28]
-            model_file=str(tokenizer_path)
-        )
-
-        inputs = (
-            InputRecorder(  # pyre-fixme[16]
-                tokenizer,
-                calibration_seq_length,
-                None,  # input_prep_func
-                pad_calibration_inputs,
-                model.vocab_size,
-            )
-            .record_inputs(
-                calibration_tasks,
-                calibration_limit,
-            )
-            .get_inputs()
-        )
-
-        gptq_quantizer = Int8DynActInt4WeightGPTQQuantizer(
-            blocksize,
-            percdamp,
-            group_size,
-        )  # TODO: separate computation and checkpoint dtype for GPTQ.
-        model = gptq_quantizer.quantize(model, inputs)
-        return model
     elif qmode == "vulkan_4w":
         from executorch.backends.vulkan._passes import VkInt4WeightOnlyQuantizer
 
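For reference, a minimal sketch of how quantize() might be called for one of the modes that survive this change. The keyword names follow the hunks above; the DType values and group size are illustrative assumptions, not taken from this diff.

    from executorch.extension.llm.export.builder import DType

    # Hypothetical call; argument names mirror the diff above, values are guesses.
    model = quantize(
        model,                         # the eager nn.Module to quantize
        qmode="8da4w",                 # 8-bit dynamic activations, 4-bit weights
        computation_dtype=DType.fp32,  # dtype ops run in after dequantization
        checkpoint_dtype=DType.fp32,   # dtype the checkpoint was saved in
        group_size=256,                # weight quantization group size (assumed)
    )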
|
|
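The deleted 8da4w-gptq branch needed calibration data (an InputRecorder run over wikitext) before invoking Int8DynActInt4WeightGPTQQuantizer. The plain 8da4w mode that remains needs no calibration; below is a sketch of what that branch plausibly reduces to, assuming torchao's Int8DynActInt4WeightQuantizer API (which this diff does not show).

    import torch
    from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

    # Assumed API: groupwise 4-bit weight quantization with 8-bit dynamic
    # activation quantization; unlike GPTQ, no calibration inputs are needed.
    # model: an eager nn.Module, e.g. a llama Transformer.
    quantizer = Int8DynActInt4WeightQuantizer(
        groupsize=256,            # illustrative group size
        precision=torch.float32,  # computation dtype after dequantization
    )
    model = quantizer.quantize(model)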