File tree Expand file tree Collapse file tree 1 file changed +9
-3
lines changed Expand file tree Collapse file tree 1 file changed +9
-3
lines changed Original file line number Diff line number Diff line change @@ -668,10 +668,16 @@ def inner_forward(*args, **kwargs):
668668 elif self .args .task_type == 'reranker' :
669669 llm_model = get_llm_model (self .model , model_meta = self .model .model_meta )
670670
671- def revert_padding_free_hook ( module , args , input , output ):
672- return revert_padding_free ( output , input , self . args . padding_side )
671+ @ wraps ( model . forward . __func__ )
672+ def reranker_forward ( model , * args , ** kwargs ):
673673
674- llm_model .register_forward_hook (revert_padding_free_hook , with_kwargs = True , prepend = True )
674+ def inner_forward (* args , ** kwargs ):
675+ output = llm_model .forward (* args , ** kwargs )
676+ return revert_padding_free (output , kwargs , self .args .padding_side )
677+
678+ return transformers_seq_cls_forward (model , * args , origin_forward = inner_forward , ** kwargs )
679+
680+ model .forward = MethodType (reranker_forward , model )
675681 elif self .args .task_type == 'generative_reranker' :
676682 llm_model = get_llm_model (self .model , model_meta = self .model .model_meta )
677683
You can’t perform that action at this time.
0 commit comments