@@ -124,28 +124,28 @@ def run_gptq(model_args, data_args, opt_args, gptq_args):
124124 """
125125
126126 # Third Party
127- from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
128- from auto_gptq.modeling._const import SUPPORTED_MODELS
129- from auto_gptq.modeling.auto import GPTQ_CAUSAL_LM_MODEL_MAP
127+ from gptqmodel import GPTQModel, QuantizeConfig
128+ from gptqmodel.models._const import SUPPORTED_MODELS
129+ from gptqmodel.models.auto import MODEL_MAP
130130
131131 # Local
132132 from fms_mo.utils.custom_gptq_models import custom_gptq_classes
133133
134134 logger = set_log_level(opt_args.log_level, "fms_mo.run_gptq")
135135
136- quantize_config = BaseQuantizeConfig(
136+ quantize_config = QuantizeConfig(
137137     bits=gptq_args.bits,
138138     group_size=gptq_args.group_size,
139139     desc_act=gptq_args.desc_act,
140140     damp_percent=gptq_args.damp_percent,
141141 )
142142
143- # Add custom model_type mapping to auto_gptq LUT so AutoGPTQForCausalLM can recognize them.
143+ # Add custom model_type mapping to the gptqmodel LUT so GPTQModel can recognize them.
144144 for mtype, cls in custom_gptq_classes.items():
145145     SUPPORTED_MODELS.append(mtype)
146-     GPTQ_CAUSAL_LM_MODEL_MAP[mtype] = cls
146+     MODEL_MAP[mtype] = cls
147147
148- model = AutoGPTQForCausalLM.from_pretrained(
148+ model = GPTQModel.from_pretrained(
149149     model_args.model_name_or_path,
150150     quantize_config=quantize_config,
151151     torch_dtype=model_args.torch_dtype,
0 commit comments