Skip to content

Commit 5de484f

Browse files
authored
fix (#6088)
1 parent ff8b797 commit 5de484f

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

swift/trainers/mixin.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -668,10 +668,16 @@ def inner_forward(*args, **kwargs):
668668
elif self.args.task_type == 'reranker':
669669
llm_model = get_llm_model(self.model, model_meta=self.model.model_meta)
670670

671-
def revert_padding_free_hook(module, args, input, output):
672-
return revert_padding_free(output, input, self.args.padding_side)
671+
@wraps(model.forward.__func__)
672+
def reranker_forward(model, *args, **kwargs):
673673

674-
llm_model.register_forward_hook(revert_padding_free_hook, with_kwargs=True, prepend=True)
674+
def inner_forward(*args, **kwargs):
675+
output = llm_model.forward(*args, **kwargs)
676+
return revert_padding_free(output, kwargs, self.args.padding_side)
677+
678+
return transformers_seq_cls_forward(model, *args, origin_forward=inner_forward, **kwargs)
679+
680+
model.forward = MethodType(reranker_forward, model)
675681
elif self.args.task_type == 'generative_reranker':
676682
llm_model = get_llm_model(self.model, model_meta=self.model.model_meta)
677683

0 commit comments

Comments
 (0)