We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent b41c3be commit 43c7e5fCopy full SHA for 43c7e5f
model2vec/distill/inference.py
@@ -159,7 +159,7 @@ def create_output_embeddings_from_model(
159
out = out.float()
160
161
# Add the output to the intermediate weights
162
- intermediate_weights.append(out[:, 1].detach().cpu().numpy())
+ intermediate_weights.append(out.mean(1).detach().cpu().numpy())
163
164
# Concatenate the intermediate weights
165
out_weights = np.concatenate(intermediate_weights)
0 commit comments