Skip to content

Commit 8dd1033

Browse files
committed
feat: docstring with examples for knn scorer train
1 parent 0bde8d7 commit 8dd1033

File tree

1 file changed

+33
-0
lines changed
  • autointent/modules/scoring/_knn

1 file changed

+33
-0
lines changed

autointent/modules/scoring/_knn/knn.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,39 @@ class KNNScorer(ScoringModule):
4545
:ivar _vector_index: VectorIndex instance for neighbor retrieval.
4646
:ivar name: Name of the scorer, defaults to "knn".
4747
:ivar prebuilt_index: Flag indicating if the vector index is prebuilt.
48+
49+
Examples
50+
--------
51+
Creating and fitting the KNNScorer:
52+
>>> from autointent.modules import KNNScorer
53+
>>> utterances = ["hello", "how are you?"]
54+
>>> labels = ["greeting", "greeting"]
55+
>>> scorer = KNNScorer(
56+
>>> embedder_name="bert-base",
57+
>>> k=5,
58+
>>> weights="distance",
59+
>>> db_dir="/path/to/database",
60+
>>> device="cuda",
61+
>>> batch_size=32,
62+
>>> max_length=128
63+
>>> )
64+
>>> scorer.fit(utterances, labels)
65+
66+
Predicting class probabilities:
67+
>>> test_utterances = ["hi", "what's up?"]
68+
>>> probabilities = scorer.predict(test_utterances)
69+
>>> print(probabilities) # Outputs predicted class probabilities for the utterances
70+
71+
Saving and loading the scorer:
72+
>>> scorer.dump("outputs/")
73+
>>> loaded_scorer = KNNScorer(
74+
>>> embedder_name="bert-base",
75+
>>> k=5,
76+
>>> weights="distance",
77+
>>> db_dir="/path/to/database",
78+
>>> device="cuda"
79+
>>> )
80+
>>> loaded_scorer.load("outputs/")
4881
"""
4982

5083
weights: WEIGHT_TYPES

0 commit comments

Comments
 (0)