@@ -388,7 +388,7 @@ def get_model_tokenizer_baichuan2(model_dir: str,
     if model is not None:
         new_forward = MethodType(patch_baichuan2_lm_head_forward,
                                  model.lm_head)
-        if hasattr(model, '_old_forward'):
+        if hasattr(model, '_old_forward'):  # device_map
             model.lm_head._old_forward = new_forward
         else:
             model.lm_head.forward = new_forward
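
A note on the `_old_forward` branch: when the model is loaded with `device_map`, accelerate wraps each module's `forward` and keeps the real implementation in `_old_forward`, so the patch has to land on that attribute for the device-placement hook to keep working. A minimal sketch of the pattern (the `patch_forward` helper below is hypothetical, not part of this repository):

```python
import torch
from torch import nn


def patch_forward(module: nn.Module, new_forward) -> None:
    """Install new_forward, respecting an accelerate-style device_map hook."""
    if hasattr(module, '_old_forward'):  # a hook wrapper owns module.forward
        module._old_forward = new_forward
    else:
        module.forward = new_forward


# toy usage: double the output of a plain linear layer
layer = nn.Linear(4, 4)
orig = layer.forward
patch_forward(layer, lambda x: orig(x) * 2)
print(layer(torch.randn(2, 4)).shape)
```
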
@@ -479,12 +479,12 @@ def get_model_tokenizer_chatglm(model_dir: str,
                                                      **kwargs)
     if model is not None:
         from torch.nn import CrossEntropyLoss
-        _old_forward = CrossEntropyLoss.forward
+        __old_forward = CrossEntropyLoss.forward
 
         def cross_entropy_forward(self, inputs: Tensor,
                                   target: Tensor) -> Tensor:
             target = target.to(device=inputs.device)
-            return _old_forward(self, inputs, target)
+            return __old_forward(self, inputs, target)
 
         CrossEntropyLoss.forward = cross_entropy_forward
     return model, tokenizer
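
The ChatGLM hunk only moves the loss target onto the logits' device before delegating to the stock `CrossEntropyLoss.forward`; under `device_map` the labels can arrive on a different device than the logits. A tiny CPU-only illustration of the cast (toy tensors, not the repository's code):

```python
import torch
from torch.nn import CrossEntropyLoss

logits = torch.randn(4, 10)          # under device_map these may sit on a GPU
labels = torch.randint(0, 10, (4,))  # while the labels live somewhere else
loss = CrossEntropyLoss()(logits, labels.to(logits.device))  # what the patch does
print(loss.item())
```
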
@@ -768,6 +768,21 @@ def get_model_tokenizer_qwen_chat(*args, **kwargs):
     return model, tokenizer
 
 
+def fix_qwen_inplace_bug(model) -> None:
+    first_drop = model.transformer.drop
+    if first_drop.p == 0.:
+        # fix in-place operation bug
+        __old_forward = first_drop.forward
+        if not hasattr(first_drop, '__old_forward'):  # Avoid double patching
+            first_drop.__old_forward = __old_forward
+            if hasattr(first_drop, '_old_forward'):  # device_map
+                first_drop._old_forward = lambda *args, **kwargs: __old_forward(
+                    *args, **kwargs).clone()
+            else:
+                first_drop.forward = lambda *args, **kwargs: __old_forward(
+                    *args, **kwargs).clone()
+
+
 @register_model(
     ModelType.qwen_vl_chat,
     'qwen/Qwen-VL-Chat',
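
The new `fix_qwen_inplace_bug` helper parks the original bound `forward` on the dropout module and patches only once, wrapping the output in `.clone()` so a `p=0` dropout no longer returns the embedding tensor itself for later in-place edits. A self-contained sketch of the same idea on a plain `nn.Dropout` (the `clone_output` helper is hypothetical, not the repository's API):

```python
import torch
from torch import nn


def clone_output(module: nn.Module) -> None:
    """Make module.forward return a cloned tensor; safe to call repeatedly."""
    if hasattr(module, '__old_forward'):  # already patched once
        return
    old_forward = module.forward
    module.__old_forward = old_forward    # keep the original reachable
    module.forward = lambda *args, **kwargs: old_forward(*args, **kwargs).clone()


drop = nn.Dropout(p=0.)  # with p=0, Dropout hands back its input tensor as-is
clone_output(drop)
x = torch.randn(3)
print(drop(x).data_ptr() != x.data_ptr())  # True: the output is now a fresh tensor
```
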
@@ -797,14 +812,7 @@ def get_model_tokenizer_qwen_vl(model_dir: str,
     model, tokenizer = get_qwen_function(model_dir, torch_dtype, model_kwargs,
                                          load_model, **kwargs)
     if model is not None:
-        first_drop = model.transformer.drop
-        if first_drop.p == 0.:
-            # fix in-place operation bug
-            _old_forward = first_drop.forward
-            if not hasattr(_old_forward, '_patching'):
-                first_drop.forward = lambda *args, **kwargs: _old_forward(
-                    *args, **kwargs).clone()
-                first_drop.forward._patching = True
+        fix_qwen_inplace_bug(model)
 
         _old_decode = tokenizer._decode
 
@@ -817,9 +825,9 @@ def _new_decode(*args, skip_special_tokens=False, **kwargs) -> str:
             else:
                 return _old_decode(*args, skip_special_tokens=False, **kwargs)
 
-        if not hasattr(_old_decode, '_patching'):
+        if not hasattr(tokenizer, '_old_decode'):  # avoid double patching
+            tokenizer._old_decode = _old_decode
             tokenizer._decode = _new_decode
-            tokenizer._decode._patching = True
 
     return model, tokenizer
 
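
The `_decode` patch follows the same convention: the original method is stored on the tokenizer as `_old_decode`, so loading the model twice cannot stack wrappers, unlike the previous `_patching` flag that lived on the replaced function object. A minimal sketch with a stand-in tokenizer (the `Tok` class and `patch_decode` helper are illustrative, not the Hugging Face API):

```python
class Tok:
    """Stand-in for a tokenizer exposing only the method we patch."""

    def _decode(self, ids, skip_special_tokens=False, **kwargs):
        return f'decode({ids}, skip_special_tokens={skip_special_tokens})'


def patch_decode(tokenizer) -> None:
    """Pin skip_special_tokens=False; calling this twice is a no-op."""
    if hasattr(tokenizer, '_old_decode'):  # avoid double patching
        return
    old_decode = tokenizer._decode
    tokenizer._old_decode = old_decode     # keep the original reachable

    def _new_decode(*args, skip_special_tokens=False, **kwargs):
        return old_decode(*args, skip_special_tokens=False, **kwargs)

    tokenizer._decode = _new_decode


tok = Tok()
patch_decode(tok)
patch_decode(tok)  # second call does nothing
print(tok._decode([1, 2, 3], skip_special_tokens=True))
```
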
@@ -847,14 +855,7 @@ def get_model_tokenizer_qwen_audio(model_dir: str,
     model, tokenizer = get_qwen_function(model_dir, torch_dtype, model_kwargs,
                                          load_model, **kwargs)
     if model is not None:
-        first_drop = model.transformer.drop
-        if first_drop.p == 0.:
-            # fix in-place operation bug
-            _old_forward = first_drop.forward
-            if not hasattr(_old_forward, '_patching'):
-                first_drop.forward = lambda *args, **kwargs: _old_forward(
-                    *args, **kwargs).clone()
-                first_drop.forward._patching = True
+        fix_qwen_inplace_bug(model)
     return model, tokenizer
 
 
@@ -968,11 +969,20 @@ def get_model_tokenizer_qwen_intx(model_dir: str,
 
     # fix quantlinear bug
     from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import QuantLinear
-    _old_qlinear_init = QuantLinear.__init__
-    if not hasattr(_old_qlinear_init, '_patching'):
-        QuantLinear.__init__ = (lambda *args, **kwargs: _old_qlinear_init(
-            *args, kernel_switch_threshold=1, **kwargs))
-        QuantLinear.__init__._patching = True
+    __old_forward = QuantLinear.forward
+
+    def _new_forward(self, x):
+        if not self.training or not self.autogptq_cuda_available:
+            return self.__old_forward(x)
+        # fix sft no grad
+        self.autogptq_cuda_available = False
+        res = self.__old_forward(x)
+        self.autogptq_cuda_available = True
+        return res
+
+    if not hasattr(QuantLinear, '__old_forward'):  # avoid double patching
+        QuantLinear.__old_forward = __old_forward
+        QuantLinear.forward = _new_forward
     get_qwen_function = kwargs.pop('get_qwen_function',
                                    get_model_tokenizer_qwen_chat)
     model, tokenizer = get_qwen_function(model_dir, torch_dtype, model_kwargs,
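
The QuantLinear change replaces the old `__init__` patch with a `forward` patch: the old CUDA kernel in auto_gptq does not track gradients, so while the module is in training mode the wrapper temporarily disables the CUDA path and lets the differentiable PyTorch fallback run. A toy reconstruction of that flag-toggling pattern (`FakeQuantLinear` and its `cuda_kernel_available` flag are stand-ins, not auto_gptq's real attributes):

```python
import torch
from torch import nn


class FakeQuantLinear(nn.Module):
    """Stand-in: a fast path without autograd plus a differentiable fallback."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4, 4))
        self.cuda_kernel_available = True  # hypothetical flag

    def forward(self, x):
        if self.cuda_kernel_available:
            with torch.no_grad():          # fast path: gradients are lost
                return x @ self.weight
        return x @ self.weight             # slow path: gradients flow


def train_safe_forward(old_forward):
    """Route to the differentiable path while the module is training."""

    def _new_forward(self, x):
        if not self.training or not self.cuda_kernel_available:
            return old_forward(self, x)
        self.cuda_kernel_available = False  # force the torch fallback
        try:
            return old_forward(self, x)
        finally:
            self.cuda_kernel_available = True

    return _new_forward


if not hasattr(FakeQuantLinear, '_old_forward'):  # avoid double patching
    FakeQuantLinear._old_forward = FakeQuantLinear.forward
    FakeQuantLinear.forward = train_safe_forward(FakeQuantLinear._old_forward)

layer = FakeQuantLinear().train()
print(layer(torch.randn(2, 4)).requires_grad)  # True: usable for sft
```
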
@@ -1035,13 +1045,15 @@ def fix_transformers_upgrade(module: PreTrainedModel) -> None:
 def fix_gradient_checkpointing_warning() -> None:
     if version.parse(torch.__version__) < version.parse('2'):
         return
-    _old_forward = torch.utils.checkpoint.checkpoint
-    if getattr(_old_forward, '_patching', False) is False:
-        _old_forward._patching = True
+    _old_checkpoint = torch.utils.checkpoint.checkpoint
+    if not hasattr(torch.utils.checkpoint,
+                   '_old_checkpoint'):  # avoid double patching
+
+        torch.utils.checkpoint._old_checkpoint = _old_checkpoint
         torch.utils.checkpoint.checkpoint = update_wrapper(
-            lambda *args, use_reentrant=False, **kwargs: _old_forward(
+            lambda *args, use_reentrant=False, **kwargs: _old_checkpoint(
                 *args, use_reentrant=use_reentrant, **kwargs),
-            _old_forward)
+            _old_checkpoint)
 
 
 def get_model_tokenizer(
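
`fix_gradient_checkpointing_warning` now records the original function on the `torch.utils.checkpoint` module rather than tagging the wrapper, and `update_wrapper` copies the original's `__name__`/`__doc__` onto the replacement lambda so it stays introspectable. A standalone sketch of the same patch (assumes a PyTorch 2.x install; safe to call more than once):

```python
from functools import update_wrapper

import torch
import torch.utils.checkpoint


def patch_checkpoint() -> None:
    """Default use_reentrant=False for torch.utils.checkpoint.checkpoint."""
    if hasattr(torch.utils.checkpoint, '_old_checkpoint'):  # already patched
        return
    _old_checkpoint = torch.utils.checkpoint.checkpoint
    torch.utils.checkpoint._old_checkpoint = _old_checkpoint
    torch.utils.checkpoint.checkpoint = update_wrapper(
        lambda *args, use_reentrant=False, **kwargs: _old_checkpoint(
            *args, use_reentrant=use_reentrant, **kwargs),
        _old_checkpoint)  # copy __name__/__doc__ onto the wrapper


patch_checkpoint()
print(torch.utils.checkpoint.checkpoint.__name__)  # still 'checkpoint'
```
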