Skip to content

Commit 1a8569a

Browse files
committed
wip
1 parent 01317c0 commit 1a8569a

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

model2vec/distill/distillation.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -119,7 +119,7 @@ def distill_from_model(
119 119
km = KMeans(vocabulary_quantization, random_state=42)
120 120
km.fit(embeddings)
121 121
clustered_embeddings = km.predict(embeddings)
122-
mapping = {idx: x for idx, x in enumerate(clustered_embeddings)}
122+
mapping = {idx: int(x) for idx, x in enumerate(clustered_embeddings)}
123 123

124 124
embeddings = km.cluster_centers_
125 125
embeddings, _ = post_process_embeddings(embeddings, pca_dims, sif_coefficient=sif_coefficient)
@@ -128,7 +128,7 @@ def distill_from_model(
128 128
embeddings, weights = post_process_embeddings(
129 129
np.asarray(embeddings), pca_dims, sif_coefficient=sif_coefficient
130 130
)
131-
mapping = {idx: token.form for idx, token in enumerate(all_tokens)}
131+
mapping = {idx: idx for idx in range(len(all_tokens))}
132 132
# Quantize the embeddings.
133 133
embeddings = quantize_embeddings(embeddings, quantize_to)
134 134

0 commit comments

Comments
 (0)