Skip to content

Commit 4d153ed

Browse files
committed
update convert script
1 parent 3286e76 commit 4d153ed

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

docs/lm_head_to_classifier/convert_lm.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
AutoTokenizer,
99
AutoModelForSequenceClassification,
1010
)
11-
11+
# Prefix of the ``repo_id`` used when creating the hub repo and uploading
# the converted model folder (``f"{BASE_NAME}/{split_name}-seq"``).
BASE_NAME = "michaelfeil"
1212

1313
@torch.no_grad()
1414
def convert_to_sequence_classifier(
@@ -81,14 +81,28 @@ def as_no_id_yes_id(
8181
assert len(no_id_yes_id[1]) == 1
8282
return no_id_yes_id[0][0], no_id_yes_id[1][0]
8383

84+
def only_yes_id(
    model_name: str = "mixedbread-ai/mxbai-rerank-base-v2",
    yes: str = "1",
) -> tuple[int]:
    """Look up the single token id that ``yes`` encodes to.

    Loads the tokenizer for ``model_name``, encodes ``yes`` with it and
    returns the resulting id wrapped in a one-element tuple.

    Raises:
        AssertionError: if the encoding of ``yes`` does not consist of
            exactly one id.
    """
    token_ids = AutoTokenizer.from_pretrained(model_name)(yes).input_ids
    # Only a single-token encoding is accepted.
    assert len(token_ids) == 1
    (token_id,) = token_ids
    return (token_id,)
93+
8494

8595
def upload_and_convert(
8696
model_name: str = "mixedbread-ai/mxbai-rerank-base-v2",
8797
no: str = "0",
8898
yes: str = "1",
99+
uses_no_and_yes: bool = True,
89100
):
90101
"""Upload the converted sequence classifier to the hub."""
91-
no_id_yes_id = as_no_id_yes_id(model_name, yes, no)
102+
if not uses_no_and_yes:
103+
no_id_yes_id = only_yes_id(model_name, yes)
104+
else:
105+
no_id_yes_id = as_no_id_yes_id(model_name, yes, no)
92106
split_name = model_name.split("/")[1]
93107
model_cls = convert_to_sequence_classifier(f"{model_name}", no_id_yes_id)
94108
model_cls = model_cls.to(torch.float16)
@@ -99,9 +113,9 @@ def upload_and_convert(
99113
model_cls.save_pretrained(f"./{split_name}")
100114

101115
api = HfApi()
102-
api.create_repo(repo_id=f"michaelfeil/{split_name}-seq", exist_ok=True)
116+
api.create_repo(repo_id=f"{BASE_NAME}/{split_name}-seq", exist_ok=True)
103117
api.upload_folder(
104-
repo_id=f"michaelfeil/{split_name}-seq",
118+
repo_id=f"{BASE_NAME}/{split_name}-seq",
105119
folder_path=f"./{split_name}",
106120
)
107121

0 commit comments

Comments
 (0)