Skip to content

Commit b2e5d6e

Browse files
authored
fix awq quant device_map (#2488)
1 parent cf574d9 commit b2e5d6e

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

swift/llm/export.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,22 @@ def _get_dataset(*args, **kwargs):
8181
return res
8282

8383

84+
@contextmanager
85+
def _patch_move_embed(awq_model):
86+
_origin_move_embed = awq_model.move_embed
87+
88+
def _move_embed(model, device: str):
89+
if hasattr(model, '_hf_hook') and device != 'cpu':
90+
return
91+
_origin_move_embed(model, device)
92+
93+
awq_model.move_embed = _move_embed
94+
try:
95+
yield
96+
finally:
97+
awq_model.move_embed = _origin_move_embed
98+
99+
84100
def awq_model_quantize(awq_model, tokenizer, batch_size) -> None:
85101

86102
from awq.quantize import quantizer
@@ -93,7 +109,8 @@ def awq_model_quantize(awq_model, tokenizer, batch_size) -> None:
93109
group_size = 128
94110
quant_config = {'zero_point': True, 'q_group_size': group_size, 'w_bit': _args.quant_bits, 'version': 'GEMM'}
95111
logger.info('Start quantizing the model...')
96-
awq_model.quantize(tokenizer, quant_config=quant_config, n_parallel_calib_samples=batch_size)
112+
with _patch_move_embed(awq_model):
113+
awq_model.quantize(tokenizer, quant_config=quant_config, n_parallel_calib_samples=batch_size)
97114
quantizer.get_calib_dataset = _origin_get_calib_dataset # recover
98115
awq_model.model.config.quantization_config = AwqConfig(
99116
bits=_args.quant_bits, group_size=group_size, zero_point=True, version='GEMM')
@@ -260,6 +277,7 @@ def llm_export(args: ExportArguments) -> None:
260277
from awq import AutoAWQForCausalLM
261278
model, template = prepare_model_template(
262279
args, device_map=args.quant_device_map, task='export', automodel_class=AutoAWQForCausalLM)
280+
template.model = model.model
263281
awq_model_quantize(model, template.tokenizer, args.quant_batch_size)
264282
model.save_quantized(args.quant_output_dir)
265283
elif args.quant_method == 'gptq':

0 commit comments

Comments
 (0)