Skip to content

Commit cb5270b

Browse files
committed
implement doctests cleanup
1 parent 7e092ca commit cb5270b

File tree

5 files changed

+44
-2
lines changed

5 files changed

+44
-2
lines changed

autointent/modules/scoring/_description/description.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,20 @@ class DescriptionScorer(ScoringModule):
4242
4343
Examples
4444
--------
45+
.. testsetup::
46+
47+
db_dir = "doctests-db"
48+
4549
.. testcode::
4650
4751
from autointent.modules import DescriptionScorer
4852
utterances = ["what is your name?", "how old are you?"]
4953
labels = [0, 1]
5054
descriptions = ["greeting", "age-related question"]
51-
scorer = DescriptionScorer(embedder_name="sergeyzh/rubert-tiny-turbo", temperature=1.0)
55+
scorer = DescriptionScorer(
56+
embedder_name="sergeyzh/rubert-tiny-turbo",
57+
db_dir=db_dir
58+
)
5259
scorer.fit(utterances, labels, descriptions)
5360
scores = scorer.predict(["tell me about your age?"])
5461
print(scores) # Outputs similarity scores for the utterance against all descriptions
@@ -57,6 +64,11 @@ class DescriptionScorer(ScoringModule):
5764
5865
[[0.47210786 0.5278922 ]]
5966
67+
.. testcleanup::
68+
69+
import shutil
70+
shutil.rmtree(db_dir)
71+
6072
"""
6173

6274
weights_file_name: str = "description_vectors.npy"

autointent/modules/scoring/_dnnc/dnnc.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ class DNNCScorer(ScoringModule):
5757
5858
Examples
5959
--------
60+
.. testsetup::
61+
62+
db_dir = "doctests-db"
63+
6064
.. testcode::
6165
6266
from autointent.modules.scoring import DNNCScorer
@@ -66,6 +70,7 @@ class DNNCScorer(ScoringModule):
6670
cross_encoder_name="cross-encoder/ms-marco-MiniLM-L-6-v2",
6771
embedder_name="sergeyzh/rubert-tiny-turbo",
6872
k=5,
73+
db_dir=db_dir,
6974
)
7075
scorer.fit(utterances, labels)
7176
@@ -79,6 +84,11 @@ class DNNCScorer(ScoringModule):
7984
[[-8.90408421 0. ]
8085
[-8.10923195 0. ]]
8186
87+
.. testcleanup::
88+
89+
import shutil
90+
shutil.rmtree(db_dir)
91+
8292
"""
8393

8494
name = "dnnc"

autointent/modules/scoring/_knn/knn.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ class KNNScorer(ScoringModule):
4747
4848
Examples
4949
--------
50+
.. testsetup::
51+
52+
db_dir = "doctests-db"
53+
5054
.. testcode::
5155
5256
from autointent.modules.scoring import KNNScorer
@@ -55,6 +59,7 @@ class KNNScorer(ScoringModule):
5559
scorer = KNNScorer(
5660
embedder_name="sergeyzh/rubert-tiny-turbo",
5761
k=5,
62+
db_dir=db_dir,
5863
)
5964
scorer.fit(utterances, labels)
6065
test_utterances = ["hi", "what's up?"]
@@ -66,6 +71,11 @@ class KNNScorer(ScoringModule):
6671
[[0.67297815 0.32702185]
6772
[0.44031678 0.55968322]]
6873
74+
.. testcleanup::
75+
76+
import shutil
77+
shutil.rmtree(db_dir)
78+
6979
"""
7080

7181
weights: WEIGHT_TYPES

autointent/modules/scoring/_linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def __init__(
6969
self,
7070
embedder_name: str,
7171
cv: int = 3,
72-
n_jobs: int = -1,
72+
n_jobs: int | None = None,
7373
device: str = "cpu",
7474
seed: int = 0,
7575
batch_size: int = 32,

autointent/modules/scoring/_mlknn/mlknn.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ class MLKnnScorer(ScoringModule):
5959
6060
Example
6161
--------
62+
.. testsetup::
63+
64+
db_dir = "doctests-db"
65+
6266
.. testcode::
6367
6468
from autointent.modules.scoring import MLKnnScorer
@@ -67,6 +71,7 @@ class MLKnnScorer(ScoringModule):
6771
scorer = MLKnnScorer(
6872
k=5,
6973
embedder_name="sergeyzh/rubert-tiny-turbo",
74+
db_dir=db_dir,
7075
)
7176
scorer.fit(utterances, labels)
7277
test_utterances = ["Hi!", "What's up?"]
@@ -78,6 +83,11 @@ class MLKnnScorer(ScoringModule):
7883
[[0.5 0.5]
7984
[0.5 0.5]]
8085
86+
.. testcleanup::
87+
88+
import shutil
89+
shutil.rmtree(db_dir)
90+
8191
"""
8292

8393
arrays_filename: str = "probs.npz"

0 commit comments

Comments
 (0)