@@ -88,12 +88,6 @@ def _check_imports(filename) -> List[str]:
8888
8989
9090def _patch_sequence_classification (model , model_meta ):
91- # rename
92- idx = model .__class__ .__name__ .find ('For' )
93- if idx != - 1 :
94- model .__class__ .__name__ = model .__class__ .__name__ [:idx ]
95- model .__class__ .__name__ += 'ForSequenceClassification'
96-
9791 hidden_size = HfConfigFactory .get_config_attr (model .config , 'hidden_size' )
9892 initializer_range = HfConfigFactory .get_config_attr (model .config , 'initializer_range' )
9993
@@ -115,7 +109,7 @@ def _patch_sequence_classification(model, model_meta):
115109 setattr (llm_model , lm_head , nn .Identity ())
116110 break
117111
118- origin_forward = llm_model .forward
112+ origin_forward = llm_model .forward . __func__
119113
120114 @wraps (origin_forward )
121115 def new_forward (self , * args , ** kwargs ):
@@ -125,7 +119,7 @@ def new_forward(self, *args, **kwargs):
125119 input_ids = kwargs .get ('input_ids' )
126120 inputs_embeds = kwargs .get ('inputs_embeds' )
127121
128- output = origin_forward (* args , ** kwargs )
122+ output = origin_forward (self , * args , ** kwargs )
129123 output .logits = output .logits .to (self .score .weight .dtype )
130124 logits = self .score (output .logits )
131125 if input_ids is not None :
@@ -188,19 +182,22 @@ def new_forward(self, *args, **kwargs):
188182
189183@contextmanager
190184def patch_automodel_for_sequence_classification (model_meta ):
191- from_pretrained = PreTrainedModel .from_pretrained
185+ from_pretrained = PreTrainedModel .from_pretrained . __func__
192186
193187 @classmethod
194188 def _new_from_pretrained (cls , * args , ** kwargs ):
189+ cls_name = cls .__name__
190+ cls_name = cls_name .split ('For' , 1 )[0 ]
191+ cls_name += 'ForSequenceClassification'
192+ cls = type (cls_name , (cls , ), {}) # new_cls
195193 __init__ = cls .__init__
196194
197195 def __new_init__ (self , * args , ** kwargs ):
198196 __init__ (self , * args , ** kwargs )
199- if 'SequenceClassification' not in self .__class__ .__name__ :
200- _patch_sequence_classification (self , model_meta )
197+ _patch_sequence_classification (self , model_meta )
201198
202199 cls .__init__ = __new_init__
203- res = from_pretrained . __func__ (cls , * args , ** kwargs )
200+ res = from_pretrained (cls , * args , ** kwargs )
204201 cls .__init__ = __init__
205202 return res
206203
@@ -209,21 +206,21 @@ def __new_init__(self, *args, **kwargs):
209206 try :
210207 yield
211208 finally :
212- PreTrainedModel .from_pretrained = from_pretrained
209+ PreTrainedModel .from_pretrained = classmethod ( from_pretrained )
213210
214211
215212@contextmanager
216213def patch_automodel_for_awq ():
217- from_pretrained = PreTrainedModel .from_pretrained
214+ from_pretrained = PreTrainedModel .from_pretrained . __func__
218215
219216 @classmethod
220217 def _new_from_pretrained (cls , * args , ** kwargs ):
221218 kwargs .pop ('use_cache' , None )
222- return from_pretrained . __func__ (cls , * args , ** kwargs )
219+ return from_pretrained (cls , * args , ** kwargs )
223220
224221 PreTrainedModel .from_pretrained = _new_from_pretrained
225222
226223 try :
227224 yield
228225 finally :
229- PreTrainedModel .from_pretrained = from_pretrained
226+ PreTrainedModel .from_pretrained = classmethod ( from_pretrained )
0 commit comments