Skip to content

Commit 336655e

Browse files
committed
fix comment, add additional check for pad token
1 parent bae0193 commit 336655e

File tree

1 file changed

+19
-5
lines changed

1 file changed

+19
-5
lines changed

model2vec/distill/distillation.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,16 +85,30 @@ def distill_from_model(
8585
if not all_tokens:
8686
raise ValueError("The vocabulary is empty after preprocessing. Please check your token_remove_pattern.")
8787

88-
# Create the embeddings.
89-
unk_token: str | None = tokenizer.special_tokens_map.get("unk_token")
90-
pad_token: str | None = tokenizer.special_tokens_map.get("pad_token")
91-
92-
# Add the cleaned vocabulary to the tokenizer.
88+
unk_token = cast(str | None, tokenizer.special_tokens_map.get("unk_token"))
89+
pad_token = cast(str | None, tokenizer.special_tokens_map.get("pad_token"))
90+
91+
# Weird if to satisfy mypy
92+
if pad_token is None:
93+
if unk_token is not None:
94+
pad_token = unk_token
95+
logger.warning(
96+
"The pad token is not set. Setting it to the unk token. This is a workaround for models that don't have a pad token."
97+
)
98+
else:
99+
pad_token = unk_token or all_tokens[0].form
100+
logger.warning(
101+
"The pad token is not set. Setting it to the first token in the vocabulary. This is a workaround for models that don't have a pad token."
102+
)
103+
104+
# Replace the vocabulary in the tokenizer with the new vocabulary.
93105
backend_tokenizer = replace_vocabulary(backend_tokenizer, all_tokens, unk_token=unk_token, pad_token=pad_token)
94106

107+
logger.info(f"Creating embeddings for {len(all_tokens)} tokens")
95108
# Convert tokens to IDs
96109
token_ids = turn_tokens_into_ids(all_tokens, tokenizer, unk_token)
97110

111+
# Create the embeddings
98112
embeddings = create_embeddings(
99113
tokenized=token_ids, model=model, device=device, pad_token_id=tokenizer.get_vocab()[pad_token]
100114
)

0 commit comments

Comments
 (0)