Commit 2e1e58d

Enable gptq quantization through quantize API
Signed-off-by: Thara Palanivel <[email protected]>
1 parent 0f69223 commit 2e1e58d

File tree: 5 files changed (+15, -14 lines)


fms_mo/run_quant.py

Lines changed: 7 additions & 6 deletions
@@ -85,11 +85,11 @@ def quantize(
     logger.info(f"{fms_mo_args}\n{opt_args.quant_method}\n")

     if opt_args.quant_method == "gptq":
-        if not available_packages["auto_gptq"]:
+        if not available_packages["gptqmodel"]:
             raise ImportError(
                 "Quantization method has been selected as gptq but unable to use external library, "
-                "auto_gptq module not found. For more instructions on installing the appropriate "
-                "package, see https://github.com/AutoGPTQ/AutoGPTQ?tab=readme-ov-file#installation"
+                "gptqmodel module not found. For more instructions on installing the appropriate "
+                "package, see https://github.com/ModelCloud/GPTQModel/tree/main?tab=readme-ov-file#install"
             )
         run_gptq(model_args, data_args, opt_args, gptq_args)
     elif opt_args.quant_method == "fp8":
@@ -127,6 +127,7 @@ def run_gptq(model_args, data_args, opt_args, gptq_args):
     from gptqmodel import GPTQModel, QuantizeConfig
     from gptqmodel.models._const import SUPPORTED_MODELS
     from gptqmodel.models.auto import MODEL_MAP
+    from gptqmodel.utils.backend import BACKEND

     # Local
     from fms_mo.utils.custom_gptq_models import custom_gptq_classes
@@ -164,17 +165,17 @@ def run_gptq(model_args, data_args, opt_args, gptq_args):
     start_time = time.time()
     model.quantize(
         data,
-        use_triton=gptq_args.use_triton,
+        backend=BACKEND.TRITON if gptq_args.use_triton else BACKEND.AUTO,
         batch_size=gptq_args.batch_size,
-        cache_examples_on_gpu=gptq_args.cache_examples_on_gpu,
+        calibration_enable_gpu_cache=gptq_args.cache_examples_on_gpu,
     )

     logger.info(
         f"Time to quantize model at {opt_args.output_dir}: {time.time() - start_time}"
     )

     logger.info(f"Saving quantized model and tokenizer to {opt_args.output_dir}")
-    model.save_quantized(opt_args.output_dir, use_safetensors=True)
+    model.save_quantized(opt_args.output_dir)
     tokenizer.save_pretrained(opt_args.output_dir)
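For reference, the gptqmodel-backed flow that run_gptq now uses roughly follows the sketch below. The quantize() and save_quantized() calls mirror the diff above; GPTQModel.load, the QuantizeConfig fields, and the model id are assumptions drawn from the gptqmodel README rather than code in this commit:

# Hedged sketch of the calibrate-and-save flow this commit switches to.
from gptqmodel import GPTQModel, QuantizeConfig
from gptqmodel.utils.backend import BACKEND

# QuantizeConfig fields and GPTQModel.load are assumed from the gptqmodel
# README; the model id and calibration text are purely illustrative.
quant_config = QuantizeConfig(bits=4, group_size=128)
model = GPTQModel.load("ibm-granite/granite-7b-base", quant_config)

calibration_data = ["example calibration text ..."]  # hypothetical dataset
model.quantize(
    calibration_data,
    backend=BACKEND.AUTO,  # BACKEND.TRITON when use_triton is requested
    batch_size=1,
    calibration_enable_gpu_cache=False,
)
model.save_quantized("quantized_model")

Note that save_quantized no longer takes use_safetensors: dropping the flag is consistent with the test change below, which now expects HF-style model*.safetensors shards by default.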

fms_mo/utils/custom_gptq_models.py

Lines changed: 3 additions & 3 deletions
@@ -15,10 +15,10 @@
 """Allow users to add new GPTQ classes for their custom models easily."""

 # Third Party
-from auto_gptq.modeling import BaseGPTQForCausalLM
+from gptqmodel.models.base import BaseGPTQModel


-class GraniteGPTQForCausalLM(BaseGPTQForCausalLM):
+class GraniteGPTQForCausalLM(BaseGPTQModel):
     """Enable Granite for GPTQ."""

     layer_type = "GraniteDecoderLayer"
@@ -32,7 +32,7 @@ class GraniteGPTQForCausalLM(BaseGPTQForCausalLM):
     ]


-class GraniteMoeGPTQForCausalLM(BaseGPTQForCausalLM):
+class GraniteMoeGPTQForCausalLM(BaseGPTQModel):
     """Enable Granite MOE for GPTQ."""

     layer_type = "GraniteMoeDecoderLayer"

fms_mo/utils/import_utils.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@
 import torch

 optional_packages = [
-    "auto_gptq",
+    "gptqmodel",
     "exllama_kernels",
     "exllamav2_kernels",
     "llmcompressor",

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -23,7 +23,7 @@ classifiers=[
 dynamic = ["version"]
 dependencies = [
     "numpy>=1.26.4,<2.3.0",
-    "accelerate>=0.20.3,!=0.34,<1.1",
+    "accelerate>=1.2.1,!=0.34",
     "transformers>=4.45,<4.48",
     "torch>=2.2.0,<2.4",
     "tqdm>=4.66.2,<5.0",
@@ -41,7 +41,7 @@ dependencies = [
 [project.optional-dependencies]
 dev = ["pre-commit>=3.0.4,<5.0"]
 fp8 = ["llmcompressor"]
-gptq = ["gptqmodel"]
+gptq = ["Cython", "gptqmodel>=1.7.3"]
 visualize = ["matplotlib", "graphviz", "pygraphviz"]
 flash-attn = ["flash-attn>=2.5.3,<3.0"]
 opt = ["fms-model-optimizer[fp8, gptq]"]
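With these pins, the GPTQ path is still pulled in through the existing optional extra, e.g. pip install "fms-model-optimizer[gptq]", which now resolves to Cython plus gptqmodel>=1.7.3 instead of auto-gptq. Cython is presumably required to build gptqmodel's compiled kernels at install time; the commit itself does not state the reason.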

tests/build/test_launch_script.py

Lines changed: 2 additions & 2 deletions
@@ -86,7 +86,7 @@ def cleanup_env():


 @pytest.mark.skipif(
-    not available_packages["auto_gptq"],
+    not available_packages["gptqmodel"],
     reason="Only runs if auto-gptq package is installed",
 )
 def test_successful_gptq():
@@ -254,7 +254,7 @@ def _validate_quantization_output(base_dir, quant_method):

     # Check quantized model files exist
     if quant_method == "gptq":
-        assert len(glob.glob(os.path.join(base_dir, "gptq_model-*.safetensors"))) > 0
+        assert len(glob.glob(os.path.join(base_dir, "model*.safetensors"))) > 0
         assert os.path.exists(os.path.join(base_dir, "quantize_config.json")) is True
         assert os.path.exists(os.path.join(base_dir, "config.json")) is True
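A matching reload smoke test might look like the hedged sketch below; GPTQModel.load as the reload entry point is an assumption based on the gptqmodel README, while the file patterns come from the updated assertions above:

import glob
import os

from gptqmodel import GPTQModel

base_dir = "output_dir"  # hypothetical quantization output directory

# File layout checked by the updated test: HF-style safetensors shards
# plus the quantization and model configs.
assert glob.glob(os.path.join(base_dir, "model*.safetensors"))
assert os.path.exists(os.path.join(base_dir, "quantize_config.json"))
assert os.path.exists(os.path.join(base_dir, "config.json"))

model = GPTQModel.load(base_dir)  # assumed reload API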
