Commit 088815e

removing 8da4w-gptq (#11274)
### Summary

The 8da4w-gptq quantization mode is being deprecated in torchao (pytorch/ao#2235), so this removes it from the Llama quantization source transformation.

### Test plan

See CI.
1 parent 543cdb3 · commit 088815e

File tree

1 file changed: +2 additions, −56 deletions

examples/models/llama/source_transformation/quantize.py

Lines changed: 2 additions & 56 deletions
@@ -8,16 +8,14 @@
 import re
 from functools import partial
 from pathlib import Path
-from typing import Any, Dict, Optional
+from typing import Dict, Optional
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from executorch.extension.llm.export.builder import DType
 
-from sentencepiece import SentencePieceProcessor
-
 
 try:
     from fairseq2.nn.embedding import (
@@ -57,7 +55,7 @@ def quantize(  # noqa C901
 
     Args:
         model: The model to quantize.
-        qmode: The quantization mode, e.g. int8, 8da4w, 8da4w-gptq.
+        qmode: The quantization mode, e.g. int8, 8da4w.
         computation_dtype: The dtype that ops are performed in (the resulting dtype of dequantization).
             Also the dtype of the rest of the non-quantized compoents of the model.
         checkpoint_dtype: The dtype of the checkpoint, this arg exists since it is more accurate to
@@ -161,58 +159,6 @@ def quantize(  # noqa C901
         if verbose:
             print("quantized model:", model)
         return model
-    elif qmode == "8da4w-gptq":
-        # Check for required args
-        required_args: Optional[Any] = [
-            group_size,
-            calibration_limit,
-            calibration_seq_length,
-        ]
-        if any(arg is None for arg in required_args):
-            raise Exception(
-                "For 8da4w-gptq quantization, group size, calibration limit and calibration sequence length must be specified."
-            )
-        if calibration_tasks is None:
-            calibration_tasks = ["wikitext"]
-
-        try:
-            # torchao 0.3+
-            from torchao._models._eval import InputRecorder
-        except ImportError:
-            from torchao.quantization.GPTQ import InputRecorder  # pyre-ignore
-
-        from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer
-
-        if tokenizer_path is None:
-            assert checkpoint_path is not None, "checkpoint_path must be specified"
-            tokenizer_path = checkpoint_path.parent / "tokenizer.model"
-        assert tokenizer_path.is_file(), tokenizer_path
-        tokenizer = SentencePieceProcessor(  # pyre-ignore[28]
-            model_file=str(tokenizer_path)
-        )
-
-        inputs = (
-            InputRecorder(  # pyre-fixme[16]
-                tokenizer,
-                calibration_seq_length,
-                None,  # input_prep_func
-                pad_calibration_inputs,
-                model.vocab_size,
-            )
-            .record_inputs(
-                calibration_tasks,
-                calibration_limit,
-            )
-            .get_inputs()
-        )
-
-        gptq_quantizer = Int8DynActInt4WeightGPTQQuantizer(
-            blocksize,
-            percdamp,
-            group_size,
-        )  # TODO: separate computation and checkpoint dtype for GPTQ.
-        model = gptq_quantizer.quantize(model, inputs)
-        return model
     elif qmode == "vulkan_4w":
         from executorch.backends.vulkan._passes import VkInt4WeightOnlyQuantizer
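For context, the surviving 8da4w mode applies 8-bit dynamic activations with 4-bit grouped weights but no GPTQ calibration, which is why the tokenizer and calibration machinery above can be deleted. A minimal sketch of that remaining path, assuming torchao still exports Int8DynActInt4WeightQuantizer from torchao.quantization.quant_api as in the releases this file targeted; the helper name quantize_8da4w and the group_size default are illustrative, not taken from quantize.py:

```python
# Hypothetical sketch of the non-GPTQ 8da4w path (not the literal quantize.py code):
# 8-bit dynamic activations + 4-bit grouped weights, applied without calibration data.
import torch
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer  # assumed export


def quantize_8da4w(model: torch.nn.Module, group_size: int = 128) -> torch.nn.Module:
    # Unlike the removed 8da4w-gptq branch, this needs no SentencePiece tokenizer,
    # InputRecorder inputs, blocksize, or percdamp; the quantizer rewrites Linear
    # weights per group of `group_size` input channels.
    quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size)
    return quantizer.quantize(model)
```

Dropping the calibration machinery is also what lets the hunks above remove the SentencePieceProcessor import and the now-unused Any from typing.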
