|
15 | 15 | import pytest |
16 | 16 |
|
17 | 17 | from opentau.configs.default import DatasetConfig, DatasetMixtureConfig |
18 | | -from opentau.datasets.standard_data_format_mapping import DATA_FEATURES_NAME_MAPPING, LOSS_TYPE_MAPPING |
| 18 | +from opentau.datasets.standard_data_format_mapping import DATA_FEATURES_NAME_MAPPING |
19 | 19 |
|
20 | 20 |
|
21 | 21 | @pytest.mark.parametrize( |
@@ -94,121 +94,9 @@ def setup_method(self): |
94 | 94 | """Set up test fixtures before each test method.""" |
95 | 95 | # Store original state of global mappings |
96 | 96 | self.original_data_mapping = DATA_FEATURES_NAME_MAPPING.copy() |
97 | | - self.original_loss_mapping = LOSS_TYPE_MAPPING.copy() |
98 | 97 |
|
99 | 98 | def teardown_method(self): |
100 | 99 | """Clean up after each test method.""" |
101 | 100 | # Restore original state of global mappings |
102 | 101 | DATA_FEATURES_NAME_MAPPING.clear() |
103 | 102 | DATA_FEATURES_NAME_MAPPING.update(self.original_data_mapping) |
104 | | - LOSS_TYPE_MAPPING.clear() |
105 | | - LOSS_TYPE_MAPPING.update(self.original_loss_mapping) |
106 | | - |
107 | | - @pytest.mark.parametrize( |
108 | | - "data_mapping, loss_mapping, should_raise", |
109 | | - [ |
110 | | - (None, None, False), # Both None - valid |
111 | | - ({"camera0": "image"}, "MSE", False), # Both provided - valid |
112 | | - (None, "MSE", True), # Only loss_mapping provided - invalid |
113 | | - ({"camera0": "image"}, None, True), # Only data_mapping provided - invalid |
114 | | - ], |
115 | | - ) |
116 | | - def test_data_mapping_validation(self, data_mapping, loss_mapping, should_raise): |
117 | | - """Test that data_features_name_mapping and loss_type_mapping must be provided together.""" |
118 | | - if should_raise: |
119 | | - with pytest.raises( |
120 | | - ValueError, |
121 | | - match="`data_features_name_mapping` and `loss_type_mapping` have to be provided together.", |
122 | | - ): |
123 | | - DatasetConfig( |
124 | | - repo_id="test_repo", |
125 | | - data_features_name_mapping=data_mapping, |
126 | | - loss_type_mapping=loss_mapping, |
127 | | - ) |
128 | | - else: |
129 | | - # Should not raise an error |
130 | | - DatasetConfig( |
131 | | - repo_id="test_repo", data_features_name_mapping=data_mapping, loss_type_mapping=loss_mapping |
132 | | - ) |
133 | | - |
134 | | - def test_mapping_addition_to_global_dicts(self): |
135 | | - """Test that mappings are added to global dictionaries when both are provided.""" |
136 | | - test_repo_id = "test_custom_repo" |
137 | | - test_data_mapping = {"camera0": "observation.image", "state": "observation.state"} |
138 | | - test_loss_mapping = "MSE" |
139 | | - |
140 | | - # Ensure the repo_id is not already in the mappings |
141 | | - assert test_repo_id not in DATA_FEATURES_NAME_MAPPING |
142 | | - assert test_repo_id not in LOSS_TYPE_MAPPING |
143 | | - |
144 | | - # Create DatasetConfig with both mappings |
145 | | - config = DatasetConfig( # noqa: F841 |
146 | | - repo_id=test_repo_id, |
147 | | - data_features_name_mapping=test_data_mapping, |
148 | | - loss_type_mapping=test_loss_mapping, |
149 | | - ) |
150 | | - |
151 | | - # Check that mappings were added to global dictionaries |
152 | | - assert test_repo_id in DATA_FEATURES_NAME_MAPPING |
153 | | - assert test_repo_id in LOSS_TYPE_MAPPING |
154 | | - assert DATA_FEATURES_NAME_MAPPING[test_repo_id] == test_data_mapping |
155 | | - assert LOSS_TYPE_MAPPING[test_repo_id] == test_loss_mapping |
156 | | - |
157 | | - def test_mapping_not_added_when_both_none(self): |
158 | | - """Test that mappings are not added to global dictionaries when both are None.""" |
159 | | - test_repo_id = "test_none_repo" |
160 | | - |
161 | | - # Ensure the repo_id is not already in the mappings |
162 | | - assert test_repo_id not in DATA_FEATURES_NAME_MAPPING |
163 | | - assert test_repo_id not in LOSS_TYPE_MAPPING |
164 | | - |
165 | | - # Create DatasetConfig with both mappings as None |
166 | | - config = DatasetConfig(repo_id=test_repo_id, data_features_name_mapping=None, loss_type_mapping=None) # noqa: F841 |
167 | | - |
168 | | - # Check that mappings were not added to global dictionaries |
169 | | - assert test_repo_id not in DATA_FEATURES_NAME_MAPPING |
170 | | - assert test_repo_id not in LOSS_TYPE_MAPPING |
171 | | - |
172 | | - def test_mapping_overwrites_existing(self): |
173 | | - """Test that providing mappings overwrites existing entries for the same repo_id.""" |
174 | | - test_repo_id = "test_overwrite_repo" |
175 | | - original_data_mapping = {"old": "mapping"} |
176 | | - original_loss_mapping = "CE" |
177 | | - new_data_mapping = {"camera0": "observation.image", "state": "observation.state"} |
178 | | - new_loss_mapping = "MSE" |
179 | | - |
180 | | - # Add original mappings |
181 | | - DATA_FEATURES_NAME_MAPPING[test_repo_id] = original_data_mapping |
182 | | - LOSS_TYPE_MAPPING[test_repo_id] = original_loss_mapping |
183 | | - |
184 | | - # Create DatasetConfig with new mappings |
185 | | - config = DatasetConfig( # noqa: F841 |
186 | | - repo_id=test_repo_id, |
187 | | - data_features_name_mapping=new_data_mapping, |
188 | | - loss_type_mapping=new_loss_mapping, |
189 | | - ) |
190 | | - |
191 | | - # Check that mappings were overwritten |
192 | | - assert DATA_FEATURES_NAME_MAPPING[test_repo_id] == new_data_mapping |
193 | | - assert LOSS_TYPE_MAPPING[test_repo_id] == new_loss_mapping |
194 | | - assert DATA_FEATURES_NAME_MAPPING[test_repo_id] != original_data_mapping |
195 | | - assert LOSS_TYPE_MAPPING[test_repo_id] != original_loss_mapping |
196 | | - |
197 | | - def test_empty_mappings(self): |
198 | | - """Test behavior with empty mappings.""" |
199 | | - test_repo_id = "test_empty_repo" |
200 | | - empty_data_mapping = {} |
201 | | - test_loss_mapping = "MSE" |
202 | | - |
203 | | - # Create DatasetConfig with empty data mapping |
204 | | - config = DatasetConfig( # noqa: F841 |
205 | | - repo_id=test_repo_id, |
206 | | - data_features_name_mapping=empty_data_mapping, |
207 | | - loss_type_mapping=test_loss_mapping, |
208 | | - ) |
209 | | - |
210 | | - # Check that empty mapping was added |
211 | | - assert test_repo_id in DATA_FEATURES_NAME_MAPPING |
212 | | - assert test_repo_id in LOSS_TYPE_MAPPING |
213 | | - assert DATA_FEATURES_NAME_MAPPING[test_repo_id] == empty_data_mapping |
214 | | - assert LOSS_TYPE_MAPPING[test_repo_id] == test_loss_mapping |
0 commit comments