|
| 1 | +import tempfile |
| 2 | +from pathlib import Path |
| 3 | +import unittest |
| 4 | +from unittest import mock |
| 5 | +from textwrap import dedent |
| 6 | + |
1 | 7 | import numpy as np |
2 | 8 | import matplotlib.pyplot as plt |
3 | | -import unittest |
| 9 | + |
4 | 10 | from iblatlas.atlas import AllenAtlas, BrainRegions |
5 | 11 | from iblatlas.flatmaps import FlatMap |
6 | 12 | from iblatlas.plots import plot_swanson, annotate_swanson |
@@ -118,3 +124,44 @@ def test_load(self): |
118 | 124 | df_genes, gene_expression, atlas_agea = agea.load() |
119 | 125 | self.assertEqual(df_genes.shape[0], gene_expression.shape[0]) |
120 | 126 | self.assertEqual(gene_expression.shape[1:], (58, 41, 67)) |
| 127 | + |
| 128 | + |
| 129 | +class TestReadVolume(unittest.TestCase): |
| 130 | + |
| 131 | + @mock.patch('iblatlas.atlas._download_atlas_allen') |
| 132 | + def test_read_volume(self, mock_download): |
| 133 | + """Test that NRRD files are read corrrectly, and redownloaded when corrupted.""" |
| 134 | + with tempfile.TemporaryDirectory() as tmpdir: |
| 135 | + tmpdir_path = Path(tmpdir) |
| 136 | + file_image = tmpdir_path / 'annotation_25.nrrd' |
| 137 | + # create dummy files |
| 138 | + file_image.write_bytes(b'\xFF') # invalid nrrd |
| 139 | + # mock the download to just write a dummy file |
| 140 | + mock_download.side_effect = self._side_effect |
| 141 | + with self.assertLogs('iblatlas.atlas', level='ERROR') as cm: |
| 142 | + volume = AllenAtlas._read_volume(file_image) |
| 143 | + self.assertIn('An error occured when loading atlas volumes', cm.output[0]) |
| 144 | + # the mock should have been called twice, once for each file |
| 145 | + self.assertEqual(mock_download.call_count, 1) |
| 146 | + self.assertEqual(volume.shape, (3, 3, 3)) |
| 147 | + |
| 148 | + @staticmethod |
| 149 | + def _side_effect(file_path): |
| 150 | + """Write a dummy file with correct header for nrrd.""" |
| 151 | + s = dedent( |
| 152 | + """\ |
| 153 | + NRRD0003 |
| 154 | + type: unsigned char |
| 155 | + dimension: 3 |
| 156 | + sizes: 3 3 3 |
| 157 | + spacings: 1.0458000000000001 1.0458000000000001 1.0458000000000001 |
| 158 | + kinds: domain domain domain |
| 159 | + encoding: ASCII\n\n |
| 160 | + """) |
| 161 | + s += '\n'.join(map(str, list(range(3**3)))) |
| 162 | + file_path.write_bytes(s.encode()) |
| 163 | + return file_path |
| 164 | + |
| 165 | + |
| 166 | +if __name__ == "__main__": |
| 167 | + unittest.main(exit=False, verbosity=2) |
0 commit comments