Skip to content

Commit 949bbbe

Browse files
authored
fix seq_cls patcher (#2963)
1 parent 23559d4 commit 949bbbe

File tree

2 files changed

+15
-18
lines changed

2 files changed

+15
-18
lines changed

swift/llm/model/patcher.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,6 @@ def _check_imports(filename) -> List[str]:
8888

8989

9090
def _patch_sequence_classification(model, model_meta):
91-
# rename
92-
idx = model.__class__.__name__.find('For')
93-
if idx != -1:
94-
model.__class__.__name__ = model.__class__.__name__[:idx]
95-
model.__class__.__name__ += 'ForSequenceClassification'
96-
9791
hidden_size = HfConfigFactory.get_config_attr(model.config, 'hidden_size')
9892
initializer_range = HfConfigFactory.get_config_attr(model.config, 'initializer_range')
9993

@@ -115,7 +109,7 @@ def _patch_sequence_classification(model, model_meta):
115109
setattr(llm_model, lm_head, nn.Identity())
116110
break
117111

118-
origin_forward = llm_model.forward
112+
origin_forward = llm_model.forward.__func__
119113

120114
@wraps(origin_forward)
121115
def new_forward(self, *args, **kwargs):
@@ -125,7 +119,7 @@ def new_forward(self, *args, **kwargs):
125119
input_ids = kwargs.get('input_ids')
126120
inputs_embeds = kwargs.get('inputs_embeds')
127121

128-
output = origin_forward(*args, **kwargs)
122+
output = origin_forward(self, *args, **kwargs)
129123
output.logits = output.logits.to(self.score.weight.dtype)
130124
logits = self.score(output.logits)
131125
if input_ids is not None:
@@ -188,19 +182,22 @@ def new_forward(self, *args, **kwargs):
188182

189183
@contextmanager
190184
def patch_automodel_for_sequence_classification(model_meta):
191-
from_pretrained = PreTrainedModel.from_pretrained
185+
from_pretrained = PreTrainedModel.from_pretrained.__func__
192186

193187
@classmethod
194188
def _new_from_pretrained(cls, *args, **kwargs):
189+
cls_name = cls.__name__
190+
cls_name = cls_name.split('For', 1)[0]
191+
cls_name += 'ForSequenceClassification'
192+
cls = type(cls_name, (cls, ), {}) # new_cls
195193
__init__ = cls.__init__
196194

197195
def __new_init__(self, *args, **kwargs):
198196
__init__(self, *args, **kwargs)
199-
if 'SequenceClassification' not in self.__class__.__name__:
200-
_patch_sequence_classification(self, model_meta)
197+
_patch_sequence_classification(self, model_meta)
201198

202199
cls.__init__ = __new_init__
203-
res = from_pretrained.__func__(cls, *args, **kwargs)
200+
res = from_pretrained(cls, *args, **kwargs)
204201
cls.__init__ = __init__
205202
return res
206203

@@ -209,21 +206,21 @@ def __new_init__(self, *args, **kwargs):
209206
try:
210207
yield
211208
finally:
212-
PreTrainedModel.from_pretrained = from_pretrained
209+
PreTrainedModel.from_pretrained = classmethod(from_pretrained)
213210

214211

215212
@contextmanager
216213
def patch_automodel_for_awq():
217-
from_pretrained = PreTrainedModel.from_pretrained
214+
from_pretrained = PreTrainedModel.from_pretrained.__func__
218215

219216
@classmethod
220217
def _new_from_pretrained(cls, *args, **kwargs):
221218
kwargs.pop('use_cache', None)
222-
return from_pretrained.__func__(cls, *args, **kwargs)
219+
return from_pretrained(cls, *args, **kwargs)
223220

224221
PreTrainedModel.from_pretrained = _new_from_pretrained
225222

226223
try:
227224
yield
228225
finally:
229-
PreTrainedModel.from_pretrained = from_pretrained
226+
PreTrainedModel.from_pretrained = classmethod(from_pretrained)

swift/llm/model/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -316,11 +316,11 @@ def use_submodel_func(model, submodel_name: str, func_list: Optional[List[str]]
316316
submodel = getattr(model, submodel_name)
317317

318318
def _get_new_func(func_name: str):
319-
_old_func = getattr(submodel, func_name)
319+
_old_func = getattr(submodel, func_name).__func__
320320

321321
@wraps(_old_func)
322322
def _new_func(self, *args, **kwargs):
323-
res = _old_func.__func__(submodel, *args, **kwargs)
323+
res = _old_func(submodel, *args, **kwargs)
324324
if func_name == 'forward':
325325
device = find_device(args)
326326
if device is None:

0 commit comments

Comments
 (0)