Commit 3b61fec: "typing"

1 parent: b608081
1 file changed: +3 -3 lines

model2vec/distill/inference.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -46,7 +46,7 @@ def create_embeddings(
     :param pad_token_id: The pad token id. Used to pad sequences.
     :return: The output embeddings.
     """
-    model = model.to(device)
+    model = model.to(device)  # type: ignore  # Transformers error

     out_weights: np.ndarray
     intermediate_weights: list[np.ndarray] = []
```
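This hunk only silences the type checker: `PreTrainedModel.to()` can fail mypy against Transformers' stubs even though the call is fine at runtime. A minimal sketch of the same pattern, assuming a hypothetical checkpoint name that is not part of the commit:

```python
import torch
from transformers import AutoModel

# Hypothetical checkpoint for illustration; the commit does not name one.
model = AutoModel.from_pretrained("bert-base-uncased")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Same pattern as the hunk: .to() works at runtime, but mypy can reject
# it against Transformers' type stubs, hence the single-line ignore.
model = model.to(device)  # type: ignore  # Transformers error
```

Keeping the ignore on one line (rather than relaxing mypy globally) preserves checking for the rest of the function.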
```diff
@@ -98,7 +98,7 @@ def _encode_mean_using_model(model: PreTrainedModel, encodings: dict[str, torch.
     """
     encodings = {k: v.to(model.device) for k, v in encodings.items()}
     encoded: BaseModelOutputWithPoolingAndCrossAttentions = model(**encodings)
-    out: torch.Tensor = encoded.last_hidden_state.cpu()
+    out: torch.Tensor = encoded.last_hidden_state.cpu()  # type: ignore  # False positive
     # NOTE: If the dtype is bfloat 16, we convert to float32,
     # because numpy does not suport bfloat16
     # See here: https://github.com/numpy/numpy/issues/19808
```
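The changed line again only adds a targeted ignore for a checker false positive on `last_hidden_state`. The context lines below it mention the bfloat16 guard; a small self-contained sketch of that conversion, with illustrative shapes and values:

```python
import numpy as np
import torch

# numpy has no bfloat16 dtype (see the numpy issue linked above), so a
# bfloat16 tensor must be upcast to float32 before calling .numpy().
out = torch.randn(2, 4, dtype=torch.bfloat16)
if out.dtype == torch.bfloat16:
    out = out.float()
arr: np.ndarray = out.numpy()
print(arr.dtype)  # float32
```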
```diff
@@ -153,7 +153,7 @@ def post_process_embeddings(
         logger.info("Estimating word frequencies using Zipf's law, and then applying SIF.")
         inv_rank = 1 / (np.arange(2, embeddings.shape[0] + 2))
         proba = inv_rank / np.sum(inv_rank)
-        weight = (sif_coefficient / (sif_coefficient + proba))
+        weight = sif_coefficient / (sif_coefficient + proba)
     else:
         weight = np.ones(embeddings.shape[0])
```
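The parentheses dropped here were redundant, so behavior is unchanged. The line computes the SIF weight a / (a + p(w)), with p(w) estimated from rank via Zipf's law, as the preceding context lines show. A self-contained sketch with illustrative values:

```python
import numpy as np

vocab_size = 5          # illustrative; in the file this is embeddings.shape[0]
sif_coefficient = 1e-4  # illustrative value for the SIF coefficient a

inv_rank = 1 / np.arange(2, vocab_size + 2)  # Zipf's law: frequency ~ 1/rank
proba = inv_rank / np.sum(inv_rank)          # normalize into probabilities
weight = sif_coefficient / (sif_coefficient + proba)  # SIF: a / (a + p(w))

print(weight)  # frequent (low-rank) words are weighted toward 0, rare ones toward 1
```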
