|
| 1 | +import json |
| 2 | +import unittest |
| 3 | +import tempfile |
| 4 | +from pathlib import Path |
| 5 | +from unittest.mock import patch, MagicMock |
| 6 | + |
| 7 | +from ibllib.io.extractors import base |
| 8 | + |
| 9 | + |
| 10 | +class TestExtractorMaps(unittest.TestCase): |
| 11 | + """Tests for functions that return Bpod extractor classes.""" |
| 12 | + def setUp(self): |
| 13 | + # Store original __import__ |
| 14 | + self.orig_import = __import__ |
| 15 | + tmp = tempfile.TemporaryDirectory() |
| 16 | + self.addCleanup(tmp.cleanup) |
| 17 | + self.custom_extractors_path = Path(tmp.name).joinpath('task_extractor_map.json') |
| 18 | + self.custom_extractors = {'fooChoiceWorld': 'Bar'} |
| 19 | + self.projects = MagicMock() |
| 20 | + self.projects.base.__file__ = str(self.custom_extractors_path.with_name('__init__.py')) |
| 21 | + with open(self.custom_extractors_path, 'w') as fp: |
| 22 | + json.dump(self.custom_extractors, fp) |
| 23 | + |
| 24 | + def import_mock(self, name, *args): |
| 25 | + """Return mock for project_extraction imports.""" |
| 26 | + if name == 'projects' or name == 'projects.base': |
| 27 | + return self.projects |
| 28 | + return self.orig_import(name, *args) |
| 29 | + |
| 30 | + def test_get_task_extractor_map(self): |
| 31 | + """Test ibllib.io.extractors.base._get_task_extractor_map function.""" |
| 32 | + # Check the custom map is loaded |
| 33 | + with patch('builtins.__import__', side_effect=self.import_mock): |
| 34 | + extractors = base._get_task_extractor_map() |
| 35 | + self.assertTrue(self.custom_extractors.items() < extractors.items()) |
| 36 | + # Test handles case where module not installed |
| 37 | + with patch('builtins.__import__', side_effect=ModuleNotFoundError): |
| 38 | + extractors = base._get_task_extractor_map() |
| 39 | + self.assertFalse(set(self.custom_extractors.items()).issubset(set(extractors.items()))) |
| 40 | + # Remove the file and check exception is caught |
| 41 | + self.custom_extractors_path.unlink() |
| 42 | + extractors = base._get_task_extractor_map() |
| 43 | + self.assertFalse(set(self.custom_extractors.items()).issubset(set(extractors.items()))) |
| 44 | + |
| 45 | + def test_get_bpod_extractor_class(self): |
| 46 | + """Test ibllib.io.extractors.base.get_bpod_extractor_class function.""" |
| 47 | + # installe |
| 48 | + # alf_path = self.custom_extractors_path.parent.joinpath('subject', '2020-01-01', '001', 'raw_task_data_00') |
| 49 | + # alf_path.mkdir(parents=True) |
| 50 | + settings_file = Path(__file__).parent.joinpath( |
| 51 | + 'data', 'session_biased_ge5', 'raw_behavior_data', '_iblrig_taskSettings.raw.json' |
| 52 | + ) |
| 53 | + # shutil.copy(settings_file, alf_path) |
| 54 | + session_path = settings_file.parents[1] |
| 55 | + self.assertEqual('BiasedTrials', base.get_bpod_extractor_class(session_path)) |
| 56 | + session_path = str(session_path).replace('session_biased_ge5', 'session_training_ge5') |
| 57 | + self.assertEqual('TrainingTrials', base.get_bpod_extractor_class(session_path)) |
| 58 | + session_path = str(session_path).replace('session_training_ge5', 'foobar') |
| 59 | + self.assertRaises(ValueError, base.get_bpod_extractor_class, session_path) |
| 60 | + |
| 61 | + def test_protocol2extractor(self): |
| 62 | + """Test ibllib.io.extractors.base.protocol2extractor function.""" |
| 63 | + # Test fuzzy match |
| 64 | + (proc, expected), = self.custom_extractors.items() |
| 65 | + with patch('builtins.__import__', side_effect=self.import_mock): |
| 66 | + extractor = base.protocol2extractor('_mw_' + proc) |
| 67 | + self.assertEqual(expected, extractor) |
| 68 | + # Test unknown protocol |
| 69 | + self.assertRaises(ValueError, base.protocol2extractor, proc) |
| 70 | + |
| 71 | + |
| 72 | +if __name__ == '__main__': |
| 73 | + unittest.main() |
0 commit comments