|
1 | 1 | import pytest |
2 | 2 |
|
3 | | -from pydantic import ValidationError |
4 | | - |
5 | 3 | from autointent import Dataset |
6 | 4 | from autointent.context.data_handler import DataHandler |
7 | 5 | from autointent.schemas import Sample |
@@ -67,6 +65,10 @@ def sample_multilabel_data(): |
67 | 65 | } |
68 | 66 |
|
69 | 67 |
|
| 68 | +def mock_split(): |
| 69 | + return [{"utterance": "Hello!", "label": 0}] |
| 70 | + |
| 71 | + |
70 | 72 | def test_data_handler_initialization(sample_multiclass_data): |
71 | 73 | handler = DataHandler(dataset=Dataset.from_dict(sample_multiclass_data), random_seed=42) |
72 | 74 |
|
@@ -94,63 +96,81 @@ def test_data_handler_multilabel_mode(sample_multilabel_data): |
94 | 96 | assert handler.test_labels() == [[1, 0], [0, 1]] |
95 | 97 |
|
96 | 98 |
|
97 | | -def test_sample_validation(): |
98 | | - utterance = "Hello!" |
99 | | - Sample(utterance="Hello!", label=0) |
100 | | - with pytest.raises(ValueError): |
101 | | - Sample(utterance=utterance, label=[]) |
102 | | - with pytest.raises(ValueError): |
103 | | - Sample(utterance=utterance, label=-1) |
104 | | - with pytest.raises(ValueError): |
105 | | - Sample(utterance=utterance, label=[-1]) |
106 | | - |
107 | | - |
108 | | -def test_dataset_validation(): |
109 | | - mock_split = [{"utterance": "Hello!", "label": 0}] |
110 | | - |
111 | | - Dataset.from_dict({"train": mock_split}) |
112 | | - Dataset.from_dict({"train_0": mock_split, "train_1": mock_split}) |
113 | | - |
114 | | - with pytest.raises(ValueError): |
115 | | - Dataset.from_dict({}) |
116 | | - with pytest.raises(ValueError): |
117 | | - Dataset.from_dict({"train": mock_split, "train_0": mock_split, "train_1": mock_split}) |
118 | | - with pytest.raises(ValueError): |
119 | | - Dataset.from_dict({"train": mock_split, "train_0": mock_split}) |
120 | | - with pytest.raises(ValueError): |
121 | | - Dataset.from_dict({"train": mock_split, "train_1": mock_split}) |
122 | | - with pytest.raises(ValueError): |
123 | | - Dataset.from_dict({"train_0": mock_split}) |
124 | | - with pytest.raises(ValueError): |
125 | | - Dataset.from_dict({"train_1": mock_split}) |
126 | | - |
127 | | - Dataset.from_dict({"train": mock_split, "validation": mock_split}) |
128 | | - Dataset.from_dict({"train": mock_split, "validation_0": mock_split, "validation_1": mock_split}) |
129 | | - |
130 | | - with pytest.raises(ValueError): |
131 | | - Dataset.from_dict( |
132 | | - {"train": mock_split, "validation": mock_split, "validation_0": mock_split, "validation_1": mock_split}, |
133 | | - ) |
134 | | - with pytest.raises(ValueError): |
135 | | - Dataset.from_dict({"train": mock_split, "validation": mock_split, "validation_0": mock_split}) |
136 | | - with pytest.raises(ValueError): |
137 | | - Dataset.from_dict({"train": mock_split, "validation": mock_split, "validation_1": mock_split}) |
138 | | - with pytest.raises(ValueError): |
139 | | - Dataset.from_dict({"train": mock_split, "validation_0": mock_split}) |
140 | | - with pytest.raises(ValueError): |
141 | | - Dataset.from_dict({"train": mock_split, "validation_1": mock_split}) |
142 | | - |
143 | | - with pytest.raises(ValueError): |
144 | | - Dataset.from_dict({"train": mock_split, "intents": [{"id": 1}]}) |
145 | | - with pytest.raises(ValueError): |
146 | | - Dataset.from_dict({"train": [{"utterance": "Hello!", "label": 1}], "intents": [{"id": 0}]}) |
147 | | - |
148 | | - with pytest.raises(ValueError): |
149 | | - Dataset.from_dict({"train": [{"utterance": "Hello!"}]}) |
| 99 | +@pytest.mark.parametrize("label", [0, [0], None]) |
| 100 | +def test_sample_initialization(label): |
| 101 | + sample = Sample(utterance="Hello!", label=label) |
| 102 | + assert sample.label == label |
| 103 | + |
| 104 | + |
| 105 | +@pytest.mark.parametrize("label", [-1, [-1], []]) |
| 106 | +def test_sample_validation(label): |
| 107 | + with pytest.raises(ValueError): |
| 108 | + Sample(utterance="Hello!", label=label) |
| 109 | + |
| 110 | + |
| 111 | +@pytest.mark.parametrize( |
| 112 | + "mapping", |
| 113 | + [ |
| 114 | + {"train": mock_split()}, |
| 115 | + {"train": mock_split(), "test": mock_split()}, |
| 116 | + {"train_0": mock_split(), "train_1": mock_split()}, |
| 117 | + {"train_0": mock_split(), "train_1": mock_split(), "test": mock_split()}, |
| 118 | + {"train": mock_split(), "validation": mock_split()}, |
| 119 | + {"train": mock_split(), "validation": mock_split(), "test": mock_split()}, |
| 120 | + {"train": mock_split(), "validation_0": mock_split(), "validation_1": mock_split()}, |
| 121 | + {"train": mock_split(), "validation_0": mock_split(), "validation_1": mock_split(), "test": mock_split()}, |
| 122 | + {"train_0": mock_split(), "train_1": mock_split(), "validation": mock_split()}, |
| 123 | + {"train_0": mock_split(), "train_1": mock_split(), "validation": mock_split(), "test": mock_split()}, |
| 124 | + {"train_0": mock_split(), "train_1": mock_split(), "validation_0": mock_split(), "validation_1": mock_split()}, |
| 125 | + { |
| 126 | + "train_0": mock_split(), |
| 127 | + "train_1": mock_split(), |
| 128 | + "validation_0": mock_split(), |
| 129 | + "validation_1": mock_split(), |
| 130 | + "test": mock_split(), |
| 131 | + }, |
| 132 | + ] |
| 133 | +) |
| 134 | +def test_dataset_initialization(mapping): |
| 135 | + dataset = Dataset.from_dict(mapping) |
| 136 | + for split in mapping: |
| 137 | + assert split in dataset |
| 138 | + |
| 139 | + |
| 140 | +@pytest.mark.parametrize( |
| 141 | + "mapping", |
| 142 | + [ |
| 143 | + {}, |
| 144 | + {"train_0": mock_split()}, |
| 145 | + {"train_1": mock_split()}, |
| 146 | + {"train": mock_split(), "train_0": mock_split()}, |
| 147 | + {"train": mock_split(), "train_1": mock_split()}, |
| 148 | + {"train": mock_split(), "train_0": mock_split(), "train_1": mock_split()}, |
| 149 | + {"train": mock_split(), "validation_0": mock_split()}, |
| 150 | + {"train": mock_split(), "validation_1": mock_split()}, |
| 151 | + {"train": mock_split(), "validation": mock_split(), "validation_0": mock_split()}, |
| 152 | + {"train": mock_split(), "validation": mock_split(), "validation_1": mock_split()}, |
| 153 | + {"train": mock_split(), "validation": mock_split(), "validation_0": mock_split(), "validation_1": mock_split()}, |
| 154 | + {"train": mock_split(), "oos": mock_split()} |
| 155 | + ] |
| 156 | +) |
| 157 | +def test_dataset_validation(mapping): |
| 158 | + with pytest.raises(ValueError): |
| 159 | + Dataset.from_dict(mapping) |
| 160 | + |
| 161 | + |
| 162 | +@pytest.mark.parametrize( |
| 163 | + "mapping", |
| 164 | + [ |
| 165 | + {"train": [{"utterance": "Hello!", "label": 0}], "intents": [{"id": 1}]}, |
| 166 | + {"train": [{"utterance": "Hello!", "label": 0}, {"utterance": "Hello!", "label": 1}], "intents": [{"id": 0}]}, |
| 167 | + { |
| 168 | + "train": [{"utterance": "Hello!", "label": 0}, {"utterance": "Hello!", "label": 1}], |
| 169 | + "test": [{"utterance": "Hello!", "label": 0}], |
| 170 | + }, |
| 171 | + {"train": [{"utterance": "Hello!"}]}, |
| 172 | + ] |
| 173 | +) |
| 174 | +def test_intents_validation(mapping): |
150 | 175 | with pytest.raises(ValueError): |
151 | | - Dataset.from_dict( |
152 | | - {"train": mock_split, "test": [{"utterance": "Hello!", "label": 0}, {"utterance": "Hello!", "label": 1}]}, |
153 | | - ) |
154 | | - |
155 | | - with pytest.raises(ValidationError): |
156 | | - Dataset.from_dict({"train": mock_split, "oos": mock_split}) |
| 176 | + Dataset.from_dict(mapping) |
0 commit comments