@@ -8,16 +8,14 @@
 import re
 from functools import partial
 from pathlib import Path
-from typing import Any, Dict, Optional
+from typing import Dict, Optional
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from executorch.extension.llm.export.builder import DType
 
-from sentencepiece import SentencePieceProcessor
-
 
 try:
     from fairseq2.nn.embedding import (
@@ -57,7 +55,7 @@ def quantize( # noqa C901
 
     Args:
         model: The model to quantize.
-        qmode: The quantization mode, e.g. int8, 8da4w, 8da4w-gptq.
+        qmode: The quantization mode, e.g. int8, 8da4w.
         computation_dtype: The dtype that ops are performed in (the resulting dtype of dequantization).
             Also the dtype of the rest of the non-quantized components of the model.
         checkpoint_dtype: The dtype of the checkpoint, this arg exists since it is more accurate to
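Reviewer note: for anyone unfamiliar with the call site, a minimal sketch of invoking the surviving 8da4w path is below. The keyword names are assumptions read off the docstring above plus the group_size argument the 8da4w branch uses; treat it as an illustration, not the file's actual signature.

    from executorch.extension.llm.export.builder import DType

    # Hypothetical call: keyword names are inferred from the docstring,
    # not copied from the real signature. `model` is an already-loaded
    # torch.nn.Module holding the checkpoint weights.
    quantized_model = quantize(
        model,
        qmode="8da4w",                 # int8 dynamic activations, int4 weights
        computation_dtype=DType.fp32,  # dtype ops run in after dequantization
        checkpoint_dtype=DType.fp32,   # dtype the checkpoint was saved in
        group_size=128,                # groupwise granularity for int4 weights
    )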
@@ -161,58 +159,6 @@ def quantize( # noqa C901
         if verbose:
             print("quantized model:", model)
         return model
-    elif qmode == "8da4w-gptq":
-        # Check for required args
-        required_args: Optional[Any] = [
-            group_size,
-            calibration_limit,
-            calibration_seq_length,
-        ]
-        if any(arg is None for arg in required_args):
-            raise Exception(
-                "For 8da4w-gptq quantization, group size, calibration limit and calibration sequence length must be specified."
-            )
-        if calibration_tasks is None:
-            calibration_tasks = ["wikitext"]
-
-        try:
-            # torchao 0.3+
-            from torchao._models._eval import InputRecorder
-        except ImportError:
-            from torchao.quantization.GPTQ import InputRecorder  # pyre-ignore
-
-        from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer
-
-        if tokenizer_path is None:
-            assert checkpoint_path is not None, "checkpoint_path must be specified"
-            tokenizer_path = checkpoint_path.parent / "tokenizer.model"
-        assert tokenizer_path.is_file(), tokenizer_path
-        tokenizer = SentencePieceProcessor(  # pyre-ignore[28]
-            model_file=str(tokenizer_path)
-        )
-
-        inputs = (
-            InputRecorder(  # pyre-fixme[16]
-                tokenizer,
-                calibration_seq_length,
-                None,  # input_prep_func
-                pad_calibration_inputs,
-                model.vocab_size,
-            )
-            .record_inputs(
-                calibration_tasks,
-                calibration_limit,
-            )
-            .get_inputs()
-        )
-
-        gptq_quantizer = Int8DynActInt4WeightGPTQQuantizer(
-            blocksize,
-            percdamp,
-            group_size,
-        )  # TODO: separate computation and checkpoint dtype for GPTQ.
-        model = gptq_quantizer.quantize(model, inputs)
-        return model
     elif qmode == "vulkan_4w":
         from executorch.backends.vulkan._passes import VkInt4WeightOnlyQuantizer
 
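Reviewer note: with the GPTQ branch gone, 8da4w is the remaining 4-bit mode, and it is data-free. A rough sketch of that style of quantization using torchao's public API is below; this is an illustration of the technique, not the code this file actually runs.

    import torch.nn as nn

    from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_

    # Toy stand-in for the transformer; any module with Linear layers works.
    model = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 4096))

    # Data-free 8da4w: int8 dynamic activations, groupwise int4 weights.
    # No tokenizer, calibration tasks, or recorded inputs are needed, which
    # is exactly the machinery the deleted branch required for GPTQ.
    quantize_(model, int8_dynamic_activation_int4_weight(group_size=128))

The trade-off is accuracy: GPTQ minimizes layer-wise quantization error against recorded calibration inputs, while plain 8da4w rounds weights groupwise with no data.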