Skip to content

Commit ad69779

Browse files
authored
Add test for tags (#67)
1 parent 0a70364 commit ad69779

File tree

22 files changed

+374
-97
lines changed

22 files changed

+374
-97
lines changed

autointent/modules/prediction/_adaptive.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,11 @@ def dump(self, path: str) -> None:
146146
"""
147147
dump_dir = Path(path)
148148

149-
metadata = AdaptivePredictorDumpMetadata(r=self._r, tags=self.tags, n_classes=self.n_classes)
149+
metadata = AdaptivePredictorDumpMetadata(
150+
r=self._r,
151+
tags=[t.model_dump() for t in self.tags] if self.tags else None, # type: ignore[misc]
152+
n_classes=self.n_classes,
153+
)
150154

151155
with (dump_dir / self.metadata_dict_name).open("w") as file:
152156
json.dump(metadata, file, indent=4)

autointent/modules/prediction/_threshold.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,11 @@ def dump(self, path: str) -> None:
164164
)
165165

166166
dump_dir = Path(path)
167+
metadata_json = self.metadata
168+
metadata_json["tags"] = [tag.model_dump() for tag in metadata_json["tags"]] if metadata_json["tags"] else None # type: ignore[misc]
167169

168170
with (dump_dir / self.metadata_dict_name).open("w") as file:
169-
json.dump(self.metadata, file, indent=4)
171+
json.dump(metadata_json, file, indent=4)
170172

171173
def load(self, path: str) -> None:
172174
"""

autointent/modules/prediction/_tunable.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,9 +170,11 @@ def dump(self, path: str) -> None:
170170
)
171171

172172
dump_dir = Path(path)
173+
metadata_json = self.metadata
174+
metadata_json["tags"] = [tag.model_dump() for tag in metadata_json["tags"]] if metadata_json["tags"] else None # type: ignore[misc]
173175

174176
with (dump_dir / self.metadata_dict_name).open("w") as file:
175-
json.dump(self.metadata, file, indent=4)
177+
json.dump(metadata_json, file, indent=4)
176178

177179
def load(self, path: str) -> None:
178180
"""
@@ -183,9 +185,11 @@ def load(self, path: str) -> None:
183185
dump_dir = Path(path)
184186

185187
with (dump_dir / self.metadata_dict_name).open() as file:
186-
metadata: TunablePredictorDumpMetadata = json.load(file)
188+
metadata = json.load(file)
187189

188-
self.metadata = metadata
190+
metadata["tags"] = [Tag(**tag) for tag in metadata["tags"]] if metadata["tags"] else None
191+
192+
self.metadata: TunablePredictorDumpMetadata = metadata
189193
self.thresh = np.array(metadata["thresh"])
190194
self.multilabel = metadata["multilabel"]
191195
self.tags = metadata["tags"]

autointent/modules/prediction/_utils.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,23 +21,31 @@ def apply_tags(labels: npt.NDArray[Any], scores: npt.NDArray[Any], tags: list[Ta
2121
:param tags: List of `Tag` objects, where each tag specifies mutually exclusive intent IDs.
2222
:return: Adjusted array of shape (n_samples, n_classes) with binary labels.
2323
"""
24-
n_samples, _ = labels.shape
25-
res = np.copy(labels)
26-
27-
for i in range(n_samples):
28-
sample_labels = labels[i].astype(bool)
29-
sample_scores = scores[i]
30-
31-
for tag in tags:
32-
if any(sample_labels[idx] for idx in tag.intent_ids):
33-
# Find the index of the class with the highest score among the tagged indices
34-
max_score_index = max(tag.intent_ids, key=lambda idx: sample_scores[idx])
35-
# Set all other tagged indices to 0 in the result
36-
for idx in tag.intent_ids:
37-
if idx != max_score_index:
38-
res[i, idx] = 0
39-
40-
return res
24+
labels = labels.copy()
25+
26+
for tag in tags:
27+
intent_ids = tag.intent_ids
28+
29+
labels_sub = labels[:, intent_ids]
30+
scores_sub = scores[:, intent_ids]
31+
32+
assigned = labels_sub == 1
33+
num_assigned = assigned.sum(axis=1)
34+
35+
assigned_scores = np.where(assigned, scores_sub, -np.inf)
36+
37+
samples_to_adjust = np.where(num_assigned > 1)[0]
38+
39+
if samples_to_adjust.size > 0:
40+
assigned_scores_adjust = assigned_scores[samples_to_adjust, :]
41+
idx_max_adjust = assigned_scores_adjust.argmax(axis=1)
42+
43+
labels_sub[samples_to_adjust, :] = 0
44+
labels_sub[samples_to_adjust, idx_max_adjust] = 1
45+
46+
labels[:, intent_ids] = labels_sub
47+
48+
return labels
4149

4250

4351
class WrongClassificationError(Exception):

tests/assets/configs/multiclass.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
metric: prediction_accuracy
3030
search_space:
3131
- module_type: threshold
32-
thresh: [0.5, [0.5, 0.5, 0.5]]
32+
thresh: [0.5, [0.5, 0.5, 0.5, 0.5]]
3333
- module_type: tunable
3434
- module_type: argmax
3535
- module_type: jinoos

tests/assets/configs/multilabel.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,6 @@
2525
metric: prediction_accuracy
2626
search_space:
2727
- module_type: threshold
28-
thresh: [0.5, [0.5, 0.5, 0.5]]
28+
thresh: [0.5, [0.5, 0.5, 0.5, 0.5]]
2929
- module_type: tunable
3030
- module_type: adaptive

tests/assets/data/clinc_subset.json

Lines changed: 117 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@
1414
"id": 2,
1515
"name": "alarm",
1616
"description": "User wants to set or manage an alarm."
17+
},
18+
{
19+
"id": 3,
20+
"name": "alarm reservation",
21+
"tags": ["alarm", "reservation"],
22+
"regexp_full_match": [],
23+
"regexp_partial_match": [],
24+
"description": "User wants to set or manage an alarm second time."
1725
}
1826
],
1927
"train": [
@@ -138,20 +146,125 @@
138146
"label": 2
139147
},
140148
{
141-
"utterance": "how much is an overdraft fee for bank"
149+
"utterance": "how much is an overdraft fee for bank",
150+
"label": 3
151+
},
152+
{
153+
"utterance": "where is the dipstick",
154+
"label": 3
155+
},
156+
{
157+
"utterance": "where is the dipstick",
158+
"label": 3
159+
},
160+
{
161+
"utterance": "where is the dipstick",
162+
"label": 3
163+
},
164+
{
165+
"utterance": "where is the dipstick",
166+
"label": 3
167+
},
168+
{
169+
"utterance": "where is the dipstick",
170+
"label": 3
171+
},
172+
{
173+
"utterance": "where is the dipstick",
174+
"label": 3
175+
},
176+
{
177+
"utterance": "how much is 1 share of aapl"
178+
},
179+
{
180+
"utterance": "how is glue made"
181+
},
182+
{
183+
"utterance": "how much is 1 share of aapl"
184+
},
185+
{
186+
"utterance": "how is glue made"
187+
},
188+
{
189+
"utterance": "how much is 1 share of aapl"
190+
},
191+
{
192+
"utterance": "how is glue made"
193+
},
194+
{
195+
"utterance": "how much is 1 share of aapl"
196+
},
197+
{
198+
"utterance": "how is glue made"
199+
}
200+
],
201+
"test": [
202+
{
203+
"utterance": "can i make a reservation for redrobin",
204+
"label": 0
205+
},
206+
{
207+
"utterance": "does redrobin do reservations",
208+
"label": 0
209+
},
210+
{
211+
"utterance": "does acero in maplewood allow reservations",
212+
"label": 0
213+
},
214+
{
215+
"utterance": "i think my account is blocked",
216+
"label": 1
217+
},
218+
{
219+
"utterance": "why is my bank account stopping all transactions from going through",
220+
"label": 1
221+
},
222+
{
223+
"utterance": "what would cause me to be locked out of my bank account",
224+
"label": 1
225+
},
226+
{
227+
"utterance": "find out the reason why am i locked out of my bank account",
228+
"label": 1
229+
},
230+
{
231+
"utterance": "make sure my alarm is set for three thirty in the morning",
232+
"label": 2
233+
},
234+
{
235+
"utterance": "please set an alarm for mid day",
236+
"label": 2
237+
},
238+
{
239+
"utterance": "have an alarm set for three in the morning",
240+
"label": 2
241+
},
242+
{
243+
"utterance": "set an alarm for me for 10:00 and another one set for 4:00",
244+
"label": 2
142245
},
143246
{
144-
"utterance": "why are exponents preformed before multiplication in the order of operations"
247+
"utterance": "set an alarm to go to sleep and another to wake up",
248+
"label": 2
145249
},
146250
{
147-
"utterance": "what size wipers does this car take"
251+
"utterance": "how much is an overdraft fee for bank",
252+
"label": 3
148253
},
149254
{
150-
"utterance": "where is the dipstick"
255+
"utterance": "why are exponents preformed before multiplication in the order of operations",
256+
"label": 3
257+
},
258+
{
259+
"utterance": "what size wipers does this car take",
260+
"label": 3
151261
},
152262
{
153263
"utterance": "how much is 1 share of aapl"
154264
},
265+
{
266+
"utterance": "how is glue made"
267+
},
155268
{
156269
"utterance": "how is glue made"
157270
}

tests/context/datahandler/test_stratificaiton.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ def test_train_test_split(dataset):
1212

1313
assert Split.TRAIN in dataset
1414
assert Split.TEST in dataset
15-
assert dataset[Split.TRAIN].num_rows == 24
16-
assert dataset[Split.TEST].num_rows == 6
15+
assert dataset[Split.TRAIN].num_rows == 29
16+
assert dataset[Split.TEST].num_rows == 8
1717
assert dataset.get_n_classes(Split.TRAIN) == dataset.get_n_classes(Split.TEST)
1818

1919

@@ -28,6 +28,6 @@ def test_multilabel_train_test_split(dataset):
2828

2929
assert Split.TRAIN in dataset
3030
assert Split.TEST in dataset
31-
assert dataset[Split.TRAIN].num_rows == 24
32-
assert dataset[Split.TEST].num_rows == 6
31+
assert dataset[Split.TRAIN].num_rows == 30
32+
assert dataset[Split.TEST].num_rows == 7
3333
assert dataset.get_n_classes(Split.TRAIN) == dataset.get_n_classes(Split.TEST)

tests/modules/prediction/conftest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
import pytest
23

34
from autointent.context.data_handler import DataHandler
@@ -43,3 +44,8 @@ def multilabel_fit_data(dataset):
4344
scores = scorer.predict(data_handler.validation_utterances(1) + data_handler.oos_utterances(1))
4445
labels = data_handler.validation_labels(1) + [[0] * data_handler.n_classes] * len(data_handler.oos_utterances(1))
4546
return scores, labels
47+
48+
49+
@pytest.fixture
50+
def scores():
51+
return np.array([[0.05, 0.9, 0, 0.05], [0.8, 0, 0.1, 0.1], [0, 0.2, 0.7, 0.1]])

tests/modules/prediction/test_adaptive.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
def test_multilabel(multilabel_fit_data):
99
predictor = AdaptivePredictor()
1010
predictor.fit(*multilabel_fit_data)
11-
scores = np.array([[0.2, 0.9, 0], [0.8, 0, 0.6], [0, 0.4, 0.7]])
11+
scores = np.array([[0.2, 0.9, 0, 0], [0.8, 0, 0.6, 0], [0, 0.4, 0.7, 0]])
1212
predictions = predictor.predict(scores)
13-
desired = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 1]])
13+
desired = np.array([[0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 1, 0]])
1414

1515
np.testing.assert_array_equal(predictions, desired)
1616

0 commit comments

Comments
 (0)