Skip to content

Commit 861e873

Browse files
committed
Refactor tests
1 parent 6c74805 commit 861e873

File tree

1 file changed

+81
-61
lines changed

1 file changed

+81
-61
lines changed

tests/context/datahandler/test_data_handler.py

Lines changed: 81 additions & 61 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,5 @@
11
import pytest
22

3-
from pydantic import ValidationError
4-
53
from autointent import Dataset
64
from autointent.context.data_handler import DataHandler
75
from autointent.schemas import Sample
@@ -67,6 +65,10 @@ def sample_multilabel_data():
6765
}
6866

6967

68+
def mock_split():
    """Return a one-sample split used as placeholder data in the dataset tests.

    A new list (holding a new dict) is built on every call, so the
    parametrized cases never share the same object.
    """
    return [{"utterance": "Hello!", "label": 0}]
70+
71+
7072
def test_data_handler_initialization(sample_multiclass_data):
7173
handler = DataHandler(dataset=Dataset.from_dict(sample_multiclass_data), random_seed=42)
7274

@@ -94,63 +96,81 @@ def test_data_handler_multilabel_mode(sample_multilabel_data):
9496
assert handler.test_labels() == [[1, 0], [0, 1]]
9597

9698

97-
def test_sample_validation():
98-
utterance = "Hello!"
99-
Sample(utterance="Hello!", label=0)
100-
with pytest.raises(ValueError):
101-
Sample(utterance=utterance, label=[])
102-
with pytest.raises(ValueError):
103-
Sample(utterance=utterance, label=-1)
104-
with pytest.raises(ValueError):
105-
Sample(utterance=utterance, label=[-1])
106-
107-
108-
def test_dataset_validation():
109-
mock_split = [{"utterance": "Hello!", "label": 0}]
110-
111-
Dataset.from_dict({"train": mock_split})
112-
Dataset.from_dict({"train_0": mock_split, "train_1": mock_split})
113-
114-
with pytest.raises(ValueError):
115-
Dataset.from_dict({})
116-
with pytest.raises(ValueError):
117-
Dataset.from_dict({"train": mock_split, "train_0": mock_split, "train_1": mock_split})
118-
with pytest.raises(ValueError):
119-
Dataset.from_dict({"train": mock_split, "train_0": mock_split})
120-
with pytest.raises(ValueError):
121-
Dataset.from_dict({"train": mock_split, "train_1": mock_split})
122-
with pytest.raises(ValueError):
123-
Dataset.from_dict({"train_0": mock_split})
124-
with pytest.raises(ValueError):
125-
Dataset.from_dict({"train_1": mock_split})
126-
127-
Dataset.from_dict({"train": mock_split, "validation": mock_split})
128-
Dataset.from_dict({"train": mock_split, "validation_0": mock_split, "validation_1": mock_split})
129-
130-
with pytest.raises(ValueError):
131-
Dataset.from_dict(
132-
{"train": mock_split, "validation": mock_split, "validation_0": mock_split, "validation_1": mock_split},
133-
)
134-
with pytest.raises(ValueError):
135-
Dataset.from_dict({"train": mock_split, "validation": mock_split, "validation_0": mock_split})
136-
with pytest.raises(ValueError):
137-
Dataset.from_dict({"train": mock_split, "validation": mock_split, "validation_1": mock_split})
138-
with pytest.raises(ValueError):
139-
Dataset.from_dict({"train": mock_split, "validation_0": mock_split})
140-
with pytest.raises(ValueError):
141-
Dataset.from_dict({"train": mock_split, "validation_1": mock_split})
142-
143-
with pytest.raises(ValueError):
144-
Dataset.from_dict({"train": mock_split, "intents": [{"id": 1}]})
145-
with pytest.raises(ValueError):
146-
Dataset.from_dict({"train": [{"utterance": "Hello!", "label": 1}], "intents": [{"id": 0}]})
147-
148-
with pytest.raises(ValueError):
149-
Dataset.from_dict({"train": [{"utterance": "Hello!"}]})
99+
@pytest.mark.parametrize("label", [0, [0], None])
def test_sample_initialization(label):
    """A Sample built with an int label, a list label, or ``None`` stores it unchanged."""
    sample = Sample(utterance="Hello!", label=label)
    assert sample.label == label
103+
104+
105+
@pytest.mark.parametrize("label", [-1, [-1], []])
def test_sample_validation(label):
    """Constructing a Sample raises ValueError for ``-1``, ``[-1]``, and an empty list label."""
    with pytest.raises(ValueError):
        Sample(utterance="Hello!", label=label)
109+
110+
111+
@pytest.mark.parametrize(
    "mapping",
    [
        # Plain train split, with and without test.
        {"train": mock_split()},
        {"train": mock_split(), "test": mock_split()},
        # Numbered train splits (train_0 + train_1).
        {"train_0": mock_split(), "train_1": mock_split()},
        {"train_0": mock_split(), "train_1": mock_split(), "test": mock_split()},
        # Plain train with a plain validation split.
        {"train": mock_split(), "validation": mock_split()},
        {"train": mock_split(), "validation": mock_split(), "test": mock_split()},
        # Plain train with numbered validation splits.
        {"train": mock_split(), "validation_0": mock_split(), "validation_1": mock_split()},
        {"train": mock_split(), "validation_0": mock_split(), "validation_1": mock_split(), "test": mock_split()},
        # Numbered train with a plain validation split.
        {"train_0": mock_split(), "train_1": mock_split(), "validation": mock_split()},
        {"train_0": mock_split(), "train_1": mock_split(), "validation": mock_split(), "test": mock_split()},
        # Numbered train with numbered validation splits.
        {"train_0": mock_split(), "train_1": mock_split(), "validation_0": mock_split(), "validation_1": mock_split()},
        {
            "train_0": mock_split(),
            "train_1": mock_split(),
            "validation_0": mock_split(),
            "validation_1": mock_split(),
            "test": mock_split(),
        },
    ],
)
def test_dataset_initialization(mapping):
    """Each split name passed to ``Dataset.from_dict`` is contained in the resulting dataset."""
    dataset = Dataset.from_dict(mapping)
    for split in mapping:
        assert split in dataset
138+
139+
140+
@pytest.mark.parametrize(
    "mapping",
    [
        # No splits at all.
        {},
        # Only one of the two numbered train splits.
        {"train_0": mock_split()},
        {"train_1": mock_split()},
        # Plain train split combined with numbered train splits.
        {"train": mock_split(), "train_0": mock_split()},
        {"train": mock_split(), "train_1": mock_split()},
        {"train": mock_split(), "train_0": mock_split(), "train_1": mock_split()},
        # Only one of the two numbered validation splits.
        {"train": mock_split(), "validation_0": mock_split()},
        {"train": mock_split(), "validation_1": mock_split()},
        # Plain validation split combined with numbered validation splits.
        {"train": mock_split(), "validation": mock_split(), "validation_0": mock_split()},
        {"train": mock_split(), "validation": mock_split(), "validation_1": mock_split()},
        {"train": mock_split(), "validation": mock_split(), "validation_0": mock_split(), "validation_1": mock_split()},
        # A split key named "oos".
        {"train": mock_split(), "oos": mock_split()},
    ],
)
def test_dataset_validation(mapping):
    """``Dataset.from_dict`` raises ValueError for each of the listed split mappings."""
    with pytest.raises(ValueError):
        Dataset.from_dict(mapping)
160+
161+
162+
@pytest.mark.parametrize(
163+
"mapping",
164+
[
165+
{"train": [{"utterance": "Hello!", "label": 0}], "intents": [{"id": 1}]},
166+
{"train": [{"utterance": "Hello!", "label": 0}, {"utterance": "Hello!", "label": 1}], "intents": [{"id": 0}]},
167+
{
168+
"train": [{"utterance": "Hello!", "label": 0}, {"utterance": "Hello!", "label": 1}],
169+
"test": [{"utterance": "Hello!", "label": 0}],
170+
},
171+
{"train": [{"utterance": "Hello!"}]},
172+
]
173+
)
174+
def test_intents_validation(mapping):
150175
with pytest.raises(ValueError):
151-
Dataset.from_dict(
152-
{"train": mock_split, "test": [{"utterance": "Hello!", "label": 0}, {"utterance": "Hello!", "label": 1}]},
153-
)
154-
155-
with pytest.raises(ValidationError):
156-
Dataset.from_dict({"train": mock_split, "oos": mock_split})
176+
Dataset.from_dict(mapping)

0 commit comments

Comments
 (0)