Commit 6ed3081

Fix the slow inference speed bug in qwen AutoGPTQ (#187)
1 parent 4e7eef9 commit 6ed3081

File tree

1 file changed: +43 −31 lines

swift/llm/utils/model.py (43 additions, 31 deletions)
@@ -388,7 +388,7 @@ def get_model_tokenizer_baichuan2(model_dir: str,
     if model is not None:
         new_forward = MethodType(patch_baichuan2_lm_head_forward,
                                  model.lm_head)
-        if hasattr(model, '_old_forward'):
+        if hasattr(model, '_old_forward'):  # device_map
             model.lm_head._old_forward = new_forward
         else:
             model.lm_head.forward = new_forward
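
The `# device_map` comment added here marks why the patch has two branches: when accelerate dispatches a model across devices, it stashes each module's original `forward` on `_old_forward` and installs a device-moving wrapper as `forward`, so a monkey patch has to land in whichever slot holds the real forward. A minimal sketch of that pattern, with a hypothetical `patch_forward` helper that is not part of this commit:

import torch.nn as nn

def patch_forward(module: nn.Module, new_forward) -> None:
    # accelerate's device_map hooks keep the original forward in
    # `_old_forward` and make `forward` a wrapper around it.
    if hasattr(module, '_old_forward'):  # device_map hook installed
        module._old_forward = new_forward
    else:
        module.forward = new_forward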
@@ -479,12 +479,12 @@ def get_model_tokenizer_chatglm(model_dir: str,
                                                      **kwargs)
     if model is not None:
         from torch.nn import CrossEntropyLoss
-        _old_forward = CrossEntropyLoss.forward
+        __old_forward = CrossEntropyLoss.forward

         def cross_entropy_forward(self, inputs: Tensor,
                                   target: Tensor) -> Tensor:
             target = target.to(device=inputs.device)
-            return _old_forward(self, inputs, target)
+            return __old_forward(self, inputs, target)

         CrossEntropyLoss.forward = cross_entropy_forward
     return model, tokenizer
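
The chatglm hunk above only renames `_old_forward` to `__old_forward` for consistency with the new patches; the underlying fix is unchanged: move the labels onto the logits' device before computing the loss. A standalone sketch of that device-alignment fix (CPU tensors here stand in for a multi-device run):

import torch
from torch.nn import CrossEntropyLoss

loss_fct = CrossEntropyLoss()
logits = torch.randn(4, 10)          # on 'cuda:0' in a real device_map run
labels = torch.randint(0, 10, (4,))  # may still live on another device
loss = loss_fct(logits, labels.to(device=logits.device))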
@@ -768,6 +768,21 @@ def get_model_tokenizer_qwen_chat(*args, **kwargs):
     return model, tokenizer


+def fix_qwen_inplace_bug(model) -> None:
+    first_drop = model.transformer.drop
+    if first_drop.p == 0.:
+        # fix in-place operation bug
+        __old_forward = first_drop.forward
+        if not hasattr(first_drop, '__old_forward'):  # avoid double patching
+            first_drop.__old_forward = __old_forward
+            if hasattr(first_drop, '_old_forward'):  # device_map
+                first_drop._old_forward = lambda *args, **kwargs: __old_forward(
+                    *args, **kwargs).clone()
+            else:
+                first_drop.forward = lambda *args, **kwargs: __old_forward(
+                    *args, **kwargs).clone()
+
+
 @register_model(
     ModelType.qwen_vl_chat,
     'qwen/Qwen-VL-Chat',
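
Note: the added code originally assigned to `first_drop.forwad` in the last branch; the typo is corrected to `first_drop.forward` above. `fix_qwen_inplace_bug` deduplicates the patch that was previously inlined in the qwen-vl and qwen-audio loaders (the two hunks below). The reason for the `.clone()`: with `p=0.`, PyTorch's dropout returns its input tensor untouched, so a later in-place operation on the hidden states would also mutate the saved embedding output. A standalone sketch of the aliasing and the fix:

import torch
import torch.nn as nn

drop = nn.Dropout(p=0.)
x = torch.randn(2, 3)
out = drop(x)
print(out.data_ptr() == x.data_ptr())   # True: no-op dropout aliases its input
safe = drop(x).clone()                  # the patch's fix
print(safe.data_ptr() == x.data_ptr())  # False: clone breaks the aliasing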
@@ -797,14 +812,7 @@ def get_model_tokenizer_qwen_vl(model_dir: str,
     model, tokenizer = get_qwen_function(model_dir, torch_dtype, model_kwargs,
                                          load_model, **kwargs)
     if model is not None:
-        first_drop = model.transformer.drop
-        if first_drop.p == 0.:
-            # fix in-place operation bug
-            _old_forward = first_drop.forward
-            if not hasattr(_old_forward, '_patching'):
-                first_drop.forward = lambda *args, **kwargs: _old_forward(
-                    *args, **kwargs).clone()
-                first_drop.forward._patching = True
+        fix_qwen_inplace_bug(model)

         _old_decode = tokenizer._decode

@@ -817,9 +825,9 @@ def _new_decode(*args, skip_special_tokens=False, **kwargs) -> str:
             else:
                 return _old_decode(*args, skip_special_tokens=False, **kwargs)

-        if not hasattr(_old_decode, '_patching'):
+        if not hasattr(tokenizer, '_old_decode'):  # avoid double patching
+            tokenizer._old_decode = _old_decode
             tokenizer._decode = _new_decode
-            tokenizer._decode._patching = True

     return model, tokenizer

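The rewritten guard stores the original `_decode` on the tokenizer itself and uses its presence as the "already patched" flag, which both prevents double patching and keeps the original method reachable for later callers. The same idempotent-patch pattern on a toy object (illustrative, not the real tokenizer):

class Toy:
    def _decode(self, token_ids):
        return str(token_ids)

def patch(obj) -> None:
    if not hasattr(obj, '_old_decode'):  # avoid double patching
        obj._old_decode = obj._decode
        obj._decode = lambda token_ids: obj._old_decode(token_ids).strip()

t = Toy()
patch(t)
patch(t)  # second call is a no-op: the wrapper is never wrapped again
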
@@ -847,14 +855,7 @@ def get_model_tokenizer_qwen_audio(model_dir: str,
     model, tokenizer = get_qwen_function(model_dir, torch_dtype, model_kwargs,
                                          load_model, **kwargs)
     if model is not None:
-        first_drop = model.transformer.drop
-        if first_drop.p == 0.:
-            # fix in-place operation bug
-            _old_forward = first_drop.forward
-            if not hasattr(_old_forward, '_patching'):
-                first_drop.forward = lambda *args, **kwargs: _old_forward(
-                    *args, **kwargs).clone()
-                first_drop.forward._patching = True
+        fix_qwen_inplace_bug(model)
     return model, tokenizer


@@ -968,11 +969,20 @@ def get_model_tokenizer_qwen_intx(model_dir: str,

     # fix quantlinear bug
     from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import QuantLinear
-    _old_qlinear_init = QuantLinear.__init__
-    if not hasattr(_old_qlinear_init, '_patching'):
-        QuantLinear.__init__ = (lambda *args, **kwargs: _old_qlinear_init(
-            *args, kernel_switch_threshold=1, **kwargs))
-        QuantLinear.__init__._patching = True
+    __old_forward = QuantLinear.forward
+
+    def _new_forward(self, x):
+        if not self.training or not self.autogptq_cuda_available:
+            return self.__old_forward(x)
+        # fix sft no grad
+        self.autogptq_cuda_available = False
+        res = self.__old_forward(x)
+        self.autogptq_cuda_available = True
+        return res
+
+    if not hasattr(QuantLinear, '__old_forward'):  # avoid double patching
+        QuantLinear.__old_forward = __old_forward
+        QuantLinear.forward = _new_forward
     get_qwen_function = kwargs.pop('get_qwen_function',
                                    get_model_tokenizer_qwen_chat)
     model, tokenizer = get_qwen_function(model_dir, torch_dtype, model_kwargs,
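
This hunk is the fix named in the commit title: the old patch forced `kernel_switch_threshold=1` in `QuantLinear.__init__`, which effectively disabled AutoGPTQ's fast CUDA kernel for inference as well as training. The new patch keeps the CUDA kernel for inference and only flips `autogptq_cuda_available` off around training-time calls, where that kernel does not track gradients ("fix sft no grad"). The same idea on a stand-in module (a sketch; only `forward`, `training`, and `autogptq_cuda_available` mirror AutoGPTQ's real attributes):

import torch
import torch.nn as nn

class FakeQuantLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.autogptq_cuda_available = True

    def forward(self, x):
        if self.autogptq_cuda_available:
            with torch.no_grad():  # "fast kernel": breaks the autograd graph
                return 2 * x
        return 2 * x               # slower path that keeps gradients

_old_forward = FakeQuantLinear.forward

def _new_forward(self, x):
    if not self.training or not self.autogptq_cuda_available:
        return _old_forward(self, x)
    self.autogptq_cuda_available = False  # force the differentiable path
    res = _old_forward(self, x)
    self.autogptq_cuda_available = True
    return res

FakeQuantLinear.forward = _new_forward

layer = FakeQuantLinear().train()
x = torch.ones(2, requires_grad=True)
print(layer(x).requires_grad)  # True: gradients flow during training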
@@ -1035,13 +1045,15 @@ def fix_transformers_upgrade(module: PreTrainedModel) -> None:
 def fix_gradient_checkpointing_warning() -> None:
     if version.parse(torch.__version__) < version.parse('2'):
         return
-    _old_forward = torch.utils.checkpoint.checkpoint
-    if getattr(_old_forward, '_patching', False) is False:
-        _old_forward._patching = True
+    _old_checkpoint = torch.utils.checkpoint.checkpoint
+    if not hasattr(torch.utils.checkpoint,
+                   '_old_checkpoint'):  # avoid double patching
+
+        torch.utils.checkpoint._old_checkpoint = _old_checkpoint
         torch.utils.checkpoint.checkpoint = update_wrapper(
-            lambda *args, use_reentrant=False, **kwargs: _old_forward(
+            lambda *args, use_reentrant=False, **kwargs: _old_checkpoint(
                 *args, use_reentrant=use_reentrant, **kwargs),
-            _old_forward)
+            _old_checkpoint)


 def get_model_tokenizer(
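
`fix_gradient_checkpointing_warning` keeps its behavior (default `use_reentrant=False` on torch>=2) but now guards with a module-level `_old_checkpoint` attribute instead of a `_patching` flag on the function, matching the pattern used above. A usage sketch, assuming the patch has already run:

import torch
import torch.nn as nn
import torch.utils.checkpoint

layer = nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)
# No use_reentrant passed: the patched wrapper supplies use_reentrant=False,
# so torch>=2 no longer warns about the unset argument.
out = torch.utils.checkpoint.checkpoint(layer, x)
out.sum().backward()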
