Skip to content

Commit bd723e9

Browse files
committed
feat: docstring with examples for mlknn scorer train
1 parent 8dd1033 commit bd723e9

File tree

1 file changed

+37
-0
lines changed
  • autointent/modules/scoring/_mlknn

1 file changed

+37
-0
lines changed

autointent/modules/scoring/_mlknn/mlknn.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,43 @@ class MLKnnScorer(ScoringModule):
5757
:ivar metadata: Metadata about the scorer's configuration.
5858
:ivar prebuilt_index: Flag indicating if the vector index is prebuilt.
5959
:ivar name: Name of the scorer, defaults to "mlknn".
60+
61+
Example
62+
--------
63+
Creating and fitting the MLKnnScorer:
64+
>>> from knn_scorer import MLKnnScorer
65+
>>> utterances = ["what is your name?", "how are you?"]
66+
>>> labels = [["greeting"], ["greeting"]]
67+
>>> scorer = MLKnnScorer(
68+
>>> k=5,
69+
>>> embedder_name="bert-base",
70+
>>> db_dir="/path/to/database",
71+
>>> s=1.0,
72+
>>> ignore_first_neighbours=0,
73+
>>> device="cuda",
74+
>>> batch_size=32,
75+
>>> max_length=128
76+
>>> )
77+
>>> scorer.fit(utterances, labels)
78+
79+
Predicting probabilities:
80+
>>> test_utterances = ["Hi!", "What's up?"]
81+
>>> probabilities = scorer.predict(test_utterances)
82+
>>> print(probabilities) # Outputs predicted probabilities for each label
83+
84+
Predicting labels:
85+
>>> predicted_labels = scorer.predict_labels(test_utterances, thresh=0.5)
86+
>>> print(predicted_labels) # Outputs binary array for each label prediction
87+
88+
Saving and loading the scorer:
89+
>>> scorer.dump("outputs/")
90+
>>> loaded_scorer = MLKnnScorer(
91+
>>> k=5,
92+
>>> embedder_name="bert-base",
93+
>>> db_dir="/path/to/database",
94+
>>> device="cuda"
95+
>>> )
96+
>>> loaded_scorer.load("outputs/")
6097
"""
6198

6299
arrays_filename: str = "probs.npz"

0 commit comments

Comments
 (0)