@@ -3690,14 +3690,20 @@ def __init__(self, *args, **kwargs):
36903690 super ().__init__ (* args , ** kwargs )
36913691 self .vocab_size = None
36923692
3693+ if cls_out_labels := self .hparams .get ("id2label" ):
3694+ if len (cls_out_labels ) == 2 and cls_out_labels [0 ] == "LABEL_0" :
3695+ # Remove dummy labels added by AutoConfig
3696+ cls_out_labels = None
3697+ self .cls_out_labels = cls_out_labels
3698+
36933699 def set_gguf_parameters (self ):
36943700 super ().set_gguf_parameters ()
36953701 self .gguf_writer .add_causal_attention (False )
36963702 self ._try_set_pooling_type ()
36973703
3698- if cls_out_labels := self .hparams . get ( "id2label" ) :
3704+ if self .cls_out_labels :
36993705 key_name = gguf .Keys .Classifier .OUTPUT_LABELS .format (arch = gguf .MODEL_ARCH_NAMES [self .model_arch ])
3700- self .gguf_writer .add_array (key_name , [v for k , v in sorted (cls_out_labels .items ())])
3706+ self .gguf_writer .add_array (key_name , [v for k , v in sorted (self . cls_out_labels .items ())])
37013707
37023708 def set_vocab (self ):
37033709 tokens , toktypes , tokpre = self .get_vocab_base ()
@@ -3749,7 +3755,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
37493755 if name .startswith ("cls.seq_relationship" ):
37503756 return []
37513757
3752- if self .hparams . get ( "id2label" ) :
3758+ if self .cls_out_labels :
37533759 # For BertForSequenceClassification (direct projection layer)
37543760 if name == "classifier.weight" :
37553761 name = "classifier.out_proj.weight"
0 commit comments