Skip to content

Commit b3e188b

Browse files
authored
add support for roberta ranking models
1 parent d98f2a3 commit b3e188b

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

convert_hf_to_gguf.py

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -3745,12 +3745,13 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
3745 3745
if name.startswith("cls.seq_relationship"):
3746 3746
return []
3747 3747

3748-
# For BertForSequenceClassification (direct projection layer)
3749-
if name == "classifier.weight":
3750-
name = "classifier.out_proj.weight"
3748+
if self.hparams.get("id2label"):
3749+
# For BertForSequenceClassification (direct projection layer)
3750+
if name == "classifier.weight":
3751+
name = "classifier.out_proj.weight"
3751 3752

3752-
if name == "classifier.bias":
3753-
name = "classifier.out_proj.bias"
3753+
if name == "classifier.bias":
3754+
name = "classifier.out_proj.bias"
3754 3755

3755 3756
return [(self.map_tensor_name(name), data_torch)]
3756 3757

@@ -3846,7 +3847,7 @@ def _xlmroberta_set_vocab(self) -> None:
3846 3847
self.gguf_writer.add_add_eos_token(True)
3847 3848

3848 3849

3849-
@ModelBase.register("RobertaModel")
3850+
@ModelBase.register("RobertaModel", "RobertaForSequenceClassification")
3850 3851
class RobertaModel(BertModel):
3851 3852
model_arch = gguf.MODEL_ARCH.BERT
3852 3853

0 commit comments

Comments
 (0)