Skip to content

Commit ff8b797

Browse files
tastelikefeettastelikefeet
andauthored
fix multi-modal padding_free for seq_cls (#6087)
Co-authored-by: tastelikefeet <[email protected]>
1 parent 3e61395 commit ff8b797

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

swift/trainers/mixin.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -655,13 +655,14 @@ def revert_padding_free_hook(module, args, input, output):
655655
elif self.args.task_type == 'seq_cls':
656656
llm_model = get_llm_model(self.model, model_meta=self.model.model_meta)
657657

658-
def seq_cls_forward(model, **kwargs):
658+
@wraps(model.forward.__func__)
659+
def seq_cls_forward(model, *args, **kwargs):
659660

660-
def inner_forward(**kwargs):
661-
output = llm_model.forward(**kwargs)
661+
def inner_forward(*args, **kwargs):
662+
output = llm_model.forward(*args, **kwargs)
662663
return revert_padding_free(output, kwargs, self.args.padding_side)
663664

664-
return transformers_seq_cls_forward(model, origin_forward=inner_forward, **kwargs)
665+
return transformers_seq_cls_forward(model, *args, origin_forward=inner_forward, **kwargs)
665666

666667
model.forward = MethodType(seq_cls_forward, model)
667668
elif self.args.task_type == 'reranker':

0 commit comments

Comments
 (0)