@@ -3061,6 +3061,64 @@ def prepare_tensors(self):
class Qwen3Model(Qwen2Model):
    """Converter for Qwen3 checkpoints, including Qwen3-Reranker variants.

    Rerank models are detected heuristically in ``__init__`` (marker heading in
    the bundled README — TODO confirm there is no better signal in the config).
    For rerank models the "yes"/"no" rows of the output head are extracted into
    a 2-row CLS_OUT classification tensor.
    """
    model_arch = gguf.MODEL_ARCH.QWEN3

    # Extra state for rerank models; the token ids stay None for plain LLMs,
    # which is what _is_rerank keys off.
    token_false_id: int | None = None   # tokenizer id of the "no" token
    token_true_id: int | None = None    # tokenizer id of the "yes" token
    sep_token_id: int = 0
    is_tied_embeddings: bool = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # a bit hacky, but currently the only way to detect if this is a rerank model:
        # look for the Qwen3-Reranker marker heading in the model card
        readme_path = self.dir_model / "README.md"
        readme_text = ""
        if readme_path.exists():
            with readme_path.open("r", encoding="utf-8") as f:
                readme_text = f.read()
        if "# Qwen3-Reranker" in readme_text:
            self._find_rerank_config()

    @property
    def _is_rerank(self) -> bool:
        # Both classifier token ids are resolved only for rerank models.
        return self.token_false_id is not None and self.token_true_id is not None

    def _find_rerank_config(self):
        """Resolve the "yes"/"no" classifier token ids from the HF tokenizer."""
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
        self.token_false_id = tokenizer.convert_tokens_to_ids("no")
        self.token_true_id = tokenizer.convert_tokens_to_ids("yes")
        # NOTE(review): the diff rendered this literal as "\\ n"; a plain
        # newline matches the separator used by rerank prompts — confirm.
        self.sep_token_id = tokenizer.convert_tokens_to_ids("\n")  # unused, but needed for rerank check
        self.is_tied_embeddings = self.hparams.get("tie_word_embeddings", False)
        logger.info(f"gguf: token_false_id = {self.token_false_id}, token_true_id = {self.token_true_id}")
        logger.info(f"gguf: sep_token_id = {self.sep_token_id}")
        logger.info(f"gguf: is_tied_embeddings = {self.is_tied_embeddings}")

    def set_gguf_parameters(self):
        super().set_gguf_parameters()
        if self._is_rerank:
            self.gguf_writer.add_pooling_type(gguf.PoolingType.RANK)
            self.gguf_writer.add_sep_token_id(self.sep_token_id)
            self.gguf_writer.add_uint32(gguf.Keys.Classifier.OUTPUT_LABELS, 2)

    def _get_cls_out_tensor(self, data_torch: Tensor) -> Tensor:
        """Build the 2-row CLS_OUT tensor ("yes" row first, then "no") from the lm_head weights."""
        # extract "yes" and "no" tokens from the output lm_head tensor
        assert self.token_false_id is not None and self.token_true_id is not None
        false_row = data_torch[self.token_false_id]
        true_row = data_torch[self.token_true_id]
        return torch.stack([true_row, false_row], dim=0)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        if self._is_rerank:
            if self.is_tied_embeddings and "embed_tokens" in name:
                # Tied embeddings: the embedding matrix doubles as the output
                # head, so emit both the classifier head and the embeddings.
                return [
                    (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.CLS_OUT] + ".weight", self._get_cls_out_tensor(data_torch)),
                    (self.map_tensor_name(name), data_torch),
                ]
            if not self.is_tied_embeddings and "lm_head" in name:
                # this is the lm_head tensor, we need to extract the cls_out tensor
                return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.CLS_OUT] + ".weight", self._get_cls_out_tensor(data_torch))]

        return super().modify_tensors(data_torch, name, bid)
30643122
30653123@ModelBase .register ("Qwen3MoeForCausalLM" )
30663124class Qwen3MoeModel (Qwen2MoeModel ):
0 commit comments