Skip to content

Commit 6d2887b

Browse files
committed
Add model.encode and model.similarity to snippet
1 parent f9ae194 commit 6d2887b

File tree

1 file changed

+24
-2
lines changed

1 file changed

+24
-2
lines changed

packages/tasks/src/model-libraries-snippets.ts

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import type { ModelData } from "./model-data";
2-
import type { WidgetExampleTextInput } from "./widget-example";
2+
import type { WidgetExampleTextInput, WidgetExampleSentenceSimilarityInput } from "./widget-example";
33
import { LIBRARY_TASK_MAPPING } from "./library-to-tasks";
44

55
const TAG_CUSTOM_CODE = "custom_code";
@@ -704,13 +704,35 @@ export const sampleFactory = (model: ModelData): string[] => [
704704
`python -m sample_factory.huggingface.load_from_hub -r ${model.id} -d ./train_dir`,
705705
];
706706

707+
function get_widget_examples_from_st_model(model: ModelData): string[] | undefined {
708+
const widgetExample = model.widgetData?.[0] as WidgetExampleSentenceSimilarityInput | undefined;
709+
if (widgetExample) {
710+
return [widgetExample.source_sentence, ...widgetExample.sentences];
711+
}
712+
}
713+
707714
export const sentenceTransformers = (model: ModelData): string[] => {
708715
const remote_code_snippet = model.tags.includes(TAG_CUSTOM_CODE) ? ", trust_remote_code=True" : "";
716+
const exampleSentences = get_widget_examples_from_st_model(model) ?? [
717+
"The weather is lovely today.",
718+
"It's so sunny outside!",
719+
"He drove to the stadium.",
720+
];
709721

710722
return [
711723
`from sentence_transformers import SentenceTransformer
712724
713-
model = SentenceTransformer("${model.id}"${remote_code_snippet})`,
725+
# Download from the 🤗 Hub
726+
model = SentenceTransformer("${model.id}"${remote_code_snippet})
727+
728+
# Run inference
729+
texts = ${JSON.stringify(exampleSentences, null, 4)}
730+
embeddings = model.encode(texts)
731+
732+
# Get the similarity scores for the texts
733+
similarities = model.similarity(embeddings, embeddings)
734+
print(similarities.shape)
735+
# [${exampleSentences.length}, ${exampleSentences.length}]`,
714736
];
715737
};
716738

0 commit comments

Comments
 (0)