
Commit 0f69223

Initial commit for GPTQModel migration
Signed-off-by: Thara Palanivel <[email protected]>
1 parent: 2d1d91d

2 files changed: +8 −8 lines


fms_mo/run_quant.py

Lines changed: 7 additions & 7 deletions
@@ -124,28 +124,28 @@ def run_gptq(model_args, data_args, opt_args, gptq_args):
     """
 
     # Third Party
-    from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
-    from auto_gptq.modeling._const import SUPPORTED_MODELS
-    from auto_gptq.modeling.auto import GPTQ_CAUSAL_LM_MODEL_MAP
+    from gptqmodel import GPTQModel, QuantizeConfig
+    from gptqmodel.models._const import SUPPORTED_MODELS
+    from gptqmodel.models.auto import MODEL_MAP
 
     # Local
     from fms_mo.utils.custom_gptq_models import custom_gptq_classes
 
     logger = set_log_level(opt_args.log_level, "fms_mo.run_gptq")
 
-    quantize_config = BaseQuantizeConfig(
+    quantize_config = QuantizeConfig(
         bits=gptq_args.bits,
         group_size=gptq_args.group_size,
         desc_act=gptq_args.desc_act,
         damp_percent=gptq_args.damp_percent,
     )
 
-    # Add custom model_type mapping to auto_gptq LUT so AutoGPTQForCausalLM can recognize them.
+    # Add custom model_type mapping to auto_gptq LUT so GPTQModel can recognize them.
     for mtype, cls in custom_gptq_classes.items():
         SUPPORTED_MODELS.append(mtype)
-        GPTQ_CAUSAL_LM_MODEL_MAP[mtype] = cls
+        MODEL_MAP[mtype] = cls
 
-    model = AutoGPTQForCausalLM.from_pretrained(
+    model = GPTQModel.from_pretrained(
         model_args.model_name_or_path,
         quantize_config=quantize_config,
         torch_dtype=model_args.torch_dtype,
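
For readers following the migration, a minimal standalone sketch of the new call sequence is shown below. The QuantizeConfig construction and GPTQModel.from_pretrained call mirror the hunk above; the quantize() and save_quantized() calls, the model id, the calibration text, and the output directory are assumptions about the gptqmodel API (carried over from auto_gptq's interface), not part of this commit.

    # Minimal sketch of the migrated flow; assumptions are noted in comments.
    from gptqmodel import GPTQModel, QuantizeConfig

    # Mirrors the QuantizeConfig construction in the diff above,
    # with illustrative values standing in for gptq_args.
    quantize_config = QuantizeConfig(
        bits=4,             # weight precision
        group_size=128,     # quantization group size
        desc_act=False,     # activation-order reordering off
        damp_percent=0.01,  # Hessian dampening factor
    )

    # Mirrors the diff: GPTQModel.from_pretrained replaces
    # AutoGPTQForCausalLM.from_pretrained.
    model = GPTQModel.from_pretrained(
        "facebook/opt-125m",  # hypothetical model id
        quantize_config=quantize_config,
    )

    # Assumption: gptqmodel keeps auto_gptq's quantize/save_quantized
    # interface; a real run would pass many tokenized calibration samples.
    calibration = ["The quick brown fox jumps over the lazy dog."]
    model.quantize(calibration)
    model.save_quantized("opt-125m-gptq-4bit")  # hypothetical output dir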

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ dependencies = [
 [project.optional-dependencies]
 dev = ["pre-commit>=3.0.4,<5.0"]
 fp8 = ["llmcompressor"]
-gptq = ["auto_gptq>0.4.2", "optimum>=1.15.0"]
+gptq = ["gptqmodel"]
 visualize = ["matplotlib", "graphviz", "pygraphviz"]
 flash-attn = ["flash-attn>=2.5.3,<3.0"]
 opt = ["fms-model-optimizer[fp8, gptq]"]
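
With the gptq extra now pointing at gptqmodel, installing via pip install "fms-model-optimizer[gptq]" pulls in gptqmodel instead of auto_gptq and optimum. A quick sanity check under that assumption (a sketch, not part of this commit):

    # Sketch: confirm the swapped optional dependency resolves
    # after installing the gptq extra.
    import importlib.util

    assert importlib.util.find_spec("gptqmodel") is not None, "gptqmodel not installed"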
