diff --git a/examples/models/llama/source_transformation/quantize.py b/examples/models/llama/source_transformation/quantize.py
index d2e2d5396d3..19fef857865 100644
--- a/examples/models/llama/source_transformation/quantize.py
+++ b/examples/models/llama/source_transformation/quantize.py
@@ -8,7 +8,7 @@
 import re
 from functools import partial
 from pathlib import Path
-from typing import Any, Dict, Optional
+from typing import Dict, Optional
 
 import torch
 import torch.nn as nn
@@ -16,8 +16,6 @@
 
 from executorch.extension.llm.export.builder import DType
 
-from sentencepiece import SentencePieceProcessor
-
 try:
     from fairseq2.nn.embedding import (
@@ -57,7 +55,7 @@ def quantize(  # noqa C901
 
     Args:
         model: The model to quantize.
-        qmode: The quantization mode, e.g. int8, 8da4w, 8da4w-gptq.
+        qmode: The quantization mode, e.g. int8, 8da4w.
         computation_dtype: The dtype that ops are performed in (the resulting dtype of dequantization).
             Also the dtype of the rest of the non-quantized components of the model.
         checkpoint_dtype: The dtype of the checkpoint, this arg exists since it is more accurate to
@@ -161,58 +159,6 @@ def quantize(  # noqa C901
         if verbose:
             print("quantized model:", model)
         return model
-    elif qmode == "8da4w-gptq":
-        # Check for required args
-        required_args: Optional[Any] = [
-            group_size,
-            calibration_limit,
-            calibration_seq_length,
-        ]
-        if any(arg is None for arg in required_args):
-            raise Exception(
-                "For 8da4w-gptq quantization, group size, calibration limit and calibration sequence length must be specified."
-            )
-        if calibration_tasks is None:
-            calibration_tasks = ["wikitext"]
-
-        try:
-            # torchao 0.3+
-            from torchao._models._eval import InputRecorder
-        except ImportError:
-            from torchao.quantization.GPTQ import InputRecorder  # pyre-ignore
-
-        from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer
-
-        if tokenizer_path is None:
-            assert checkpoint_path is not None, "checkpoint_path must be specified"
-            tokenizer_path = checkpoint_path.parent / "tokenizer.model"
-        assert tokenizer_path.is_file(), tokenizer_path
-        tokenizer = SentencePieceProcessor(  # pyre-ignore[28]
-            model_file=str(tokenizer_path)
-        )
-
-        inputs = (
-            InputRecorder(  # pyre-fixme[16]
-                tokenizer,
-                calibration_seq_length,
-                None,  # input_prep_func
-                pad_calibration_inputs,
-                model.vocab_size,
-            )
-            .record_inputs(
-                calibration_tasks,
-                calibration_limit,
-            )
-            .get_inputs()
-        )
-
-        gptq_quantizer = Int8DynActInt4WeightGPTQQuantizer(
-            blocksize,
-            percdamp,
-            group_size,
-        )  # TODO: separate computation and checkpoint dtype for GPTQ.
-        model = gptq_quantizer.quantize(model, inputs)
-        return model
     elif qmode == "vulkan_4w":
         from executorch.backends.vulkan._passes import VkInt4WeightOnlyQuantizer
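
Note for anyone who was depending on the removed path: the deleted 8da4w-gptq branch only wired together existing torchao and sentencepiece pieces, so it can be reproduced as a standalone helper outside ExecuTorch. Below is a minimal sketch assembled from the deleted lines; the import paths and the `Int8DynActInt4WeightGPTQQuantizer(blocksize, percdamp, group_size)` signature come verbatim from the removed code, while the `gptq_quantize` wrapper and its parameter defaults are placeholders, and whether current torchao releases still ship these symbols is an assumption to verify against your installed version.

```python
# Standalone reproduction of the removed 8da4w-gptq branch (a sketch, not
# maintained code). Import paths and call signatures mirror the deleted
# lines; availability in your installed torchao version is an assumption.
from pathlib import Path

import torch
from sentencepiece import SentencePieceProcessor

try:
    # torchao 0.3+
    from torchao._models._eval import InputRecorder
except ImportError:
    from torchao.quantization.GPTQ import InputRecorder

from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer


def gptq_quantize(
    model: torch.nn.Module,
    tokenizer_path: Path,
    group_size: int = 128,  # placeholder defaults; tune for your model
    calibration_tasks=("wikitext",),
    calibration_limit: int = 100,
    calibration_seq_length: int = 128,
    pad_calibration_inputs: bool = False,
    blocksize: int = 128,
    percdamp: float = 0.01,
) -> torch.nn.Module:
    # Record calibration inputs by running the eval tasks through the
    # model's tokenizer, exactly as the removed branch did.
    tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
    inputs = (
        InputRecorder(
            tokenizer,
            calibration_seq_length,
            None,  # input_prep_func
            pad_calibration_inputs,
            model.vocab_size,
        )
        .record_inputs(list(calibration_tasks), calibration_limit)
        .get_inputs()
    )
    # GPTQ-quantize to int4 weights with int8 dynamic activations.
    quantizer = Int8DynActInt4WeightGPTQQuantizer(blocksize, percdamp, group_size)
    return quantizer.quantize(model, inputs)
```

The only functional change from the deleted branch is that the calibration knobs are explicit parameters here rather than arguments threaded through `quantize()`.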