Skip to content

Commit 8726f78

Browse files
author
Dimitar Tasev
committed
Clean-up old test functions
- Removes switch_mp_off/on, as I'm pretty sure that hasn't worked since `multiprocessing_necessary` was used - Rewrites a few tests that used switch mp off and on to now follow the new logic which makes the decision based on number of images - Removes multiprocessing_available
1 parent cd38d5d commit 8726f78

File tree

7 files changed

+39
-101
lines changed

7 files changed

+39
-101
lines changed

mantidimaging/core/gpu/test/gpu_test.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,6 @@ class GPUTest(unittest.TestCase):
2424
def __init__(self, *args, **kwargs):
2525
super(GPUTest, self).__init__(*args, **kwargs)
2626

27-
@staticmethod
28-
def run_serial(data, size, mode):
29-
"""
30-
Run the median filter in serial.
31-
"""
32-
th.switch_mp_off()
33-
cpu_result = MedianFilter.filter_func(data, size, mode)
34-
th.switch_mp_on()
35-
return cpu_result
36-
3727
@unittest.skipIf(GPU_NOT_AVAIL, reason=GPU_SKIP_REASON)
3828
def test_numpy_pad_modes_match_scipy_median_modes(self):
3929
"""
@@ -47,7 +37,7 @@ def test_numpy_pad_modes_match_scipy_median_modes(self):
4737
images = th.generate_shared_array()
4838

4939
gpu_result = MedianFilter.filter_func(images.copy(), size, mode, force_cpu=False)
50-
cpu_result = self.run_serial(images.copy(), size, mode)
40+
cpu_result = MedianFilter.filter_func(images.copy(), size, mode, force_cpu=True)
5141

5242
npt.assert_almost_equal(gpu_result[0], cpu_result[0])
5343

@@ -80,7 +70,7 @@ def test_gpu_result_matches_cpu_result_for_larger_images(self):
8070
images = th.generate_shared_array(shape=(20, N, N))
8171

8272
gpu_result = MedianFilter.filter_func(images.copy(), size, mode, force_cpu=False)
83-
cpu_result = self.run_serial(images.copy(), size, mode)
73+
cpu_result = MedianFilter.filter_func(images.copy(), size, mode, force_cpu=True)
8474

8575
npt.assert_almost_equal(gpu_result, cpu_result)
8676

@@ -95,7 +85,7 @@ def test_double_is_used_in_cuda_for_float_64_arrays(self):
9585
images = th.generate_shared_array(dtype="float64")
9686

9787
gpu_result = MedianFilter.filter_func(images.copy(), size, mode, force_cpu=False)
98-
cpu_result = self.run_serial(images.copy(), size, mode)
88+
cpu_result = MedianFilter.filter_func(images.copy(), size, mode, force_cpu=True)
9989

10090
npt.assert_almost_equal(gpu_result, cpu_result)
10191

@@ -115,7 +105,7 @@ def test_image_slicing_works(self):
115105
images = th.generate_shared_array(shape=(n_images, N, N))
116106

117107
gpu_result = MedianFilter.filter_func(images.copy(), size, mode, force_cpu=False)
118-
cpu_result = self.run_serial(images.copy(), size, mode)
108+
cpu_result = MedianFilter.filter_func(images.copy(), size, mode, force_cpu=True)
119109

120110
npt.assert_almost_equal(gpu_result, cpu_result)
121111

mantidimaging/core/operations/median_filter/test/median_filter_test.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
import numpy as np
88

99
import mantidimaging.test_helpers.unit_test_helper as th
10-
from mantidimaging.core.operations.median_filter import MedianFilter
10+
from mantidimaging.core.data.images import Images
1111
from mantidimaging.core.gpu import utility as gpu
12+
from mantidimaging.core.operations.median_filter import MedianFilter
1213
from mantidimaging.core.utility.memory_usage import get_memory_usage_linux
1314

1415
GPU_UTIL_LOC = "mantidimaging.core.gpu.utility.gpu_available"
@@ -57,17 +58,18 @@ def test_executed_no_helper_gpu(self):
5758

5859
th.assert_not_equals(result.data, original)
5960

60-
def test_executed_no_helper_seq(self):
61-
images = th.generate_images()
61+
def test_executed_seq(self):
62+
self.do_execute(th.generate_images())
6263

64+
def test_executed_par(self):
65+
self.do_execute(th.generate_images_for_parallel())
66+
67+
def do_execute(self, images: Images):
6368
size = 3
6469
mode = 'reflect'
6570

6671
original = np.copy(images.data[0])
67-
th.switch_mp_off()
6872
result = MedianFilter.filter_func(images, size, mode)
69-
th.switch_mp_on()
70-
7173
th.assert_not_equals(result.data, original)
7274

7375
def test_memory_change_acceptable(self):

mantidimaging/core/operations/rebin/test/rebin_test.py

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -48,19 +48,13 @@ def test_executed_uniform_par_5(self):
4848
self.do_execute_uniform(5.0)
4949

5050
def test_executed_uniform_seq_2(self):
51-
th.switch_mp_off()
5251
self.do_execute_uniform(2.0)
53-
th.switch_mp_on()
5452

5553
def test_executed_uniform_seq_5(self):
56-
th.switch_mp_off()
5754
self.do_execute_uniform(5.0)
58-
th.switch_mp_on()
5955

6056
def test_executed_uniform_seq_5_int(self):
61-
th.switch_mp_off()
6257
self.do_execute_uniform(5.0, np.int32)
63-
th.switch_mp_on()
6458

6559
def do_execute_uniform(self, val=2.0, dtype=np.float32):
6660
images = th.generate_images(dtype=dtype)
@@ -78,31 +72,28 @@ def do_execute_uniform(self, val=2.0, dtype=np.float32):
7872
self.assertEqual(result.data.dtype, dtype)
7973

8074
def test_executed_xy_par_128_256(self):
81-
self.do_execute_xy((128, 256))
75+
self.do_execute_xy(True, (128, 256))
8276

8377
def test_executed_xy_par_512_256(self):
84-
self.do_execute_xy((512, 256))
78+
self.do_execute_xy(True, (512, 256))
8579

8680
def test_executed_xy_par_1024_1024(self):
87-
self.do_execute_xy((1024, 1024))
81+
self.do_execute_xy(True, (1024, 1024))
8882

8983
def test_executed_xy_seq_128_256(self):
90-
th.switch_mp_off()
91-
self.do_execute_xy((128, 256))
92-
th.switch_mp_on()
84+
self.do_execute_xy(False, (128, 256))
9385

9486
def test_executed_xy_seq_512_256(self):
95-
th.switch_mp_off()
96-
self.do_execute_xy((512, 256))
97-
th.switch_mp_on()
87+
self.do_execute_xy(False, (512, 256))
9888

9989
def test_executed_xy_seq_1024_1024(self):
100-
th.switch_mp_off()
101-
self.do_execute_xy((1024, 1024))
102-
th.switch_mp_on()
90+
self.do_execute_xy(False, (1024, 1024))
10391

104-
def do_execute_xy(self, val=(512, 512)):
105-
images = th.generate_images()
92+
def do_execute_xy(self, is_parallel: bool, val=(512, 512)):
93+
if is_parallel:
94+
images = th.generate_images((15, 8, 10))
95+
else:
96+
images = th.generate_images()
10697
mode = 'reflect'
10798

10899
expected_x = int(val[0])

mantidimaging/core/operations/roi_normalisation/test/roi_normalisation_test.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
# SPDX - License - Identifier: GPL-3.0-or-later
33

44
import unittest
5-
65
from unittest import mock
6+
77
import numpy as np
88
import numpy.testing as npt
99

1010
import mantidimaging.test_helpers.unit_test_helper as th
11+
from mantidimaging.core.data.images import Images
1112
from mantidimaging.core.operations.roi_normalisation import RoiNormalisationFilter
1213

1314

@@ -35,15 +36,12 @@ def test_not_executed_invalid_shape(self):
3536
npt.assert_raises(ValueError, RoiNormalisationFilter.filter_func, images, air)
3637

3738
def test_executed_par(self):
38-
self.do_execute()
39+
self.do_execute(th.generate_images_for_parallel())
3940

4041
def test_executed_seq(self):
41-
th.switch_mp_off()
42-
self.do_execute()
43-
th.switch_mp_on()
42+
self.do_execute(th.generate_images())
4443

45-
def do_execute(self):
46-
images = th.generate_images()
44+
def do_execute(self, images: Images):
4745

4846
original = np.copy(images.data[0])
4947
air = [3, 3, 4, 4]

mantidimaging/core/parallel/utility.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX - License - Identifier: GPL-3.0-or-later
33

44
import ctypes
5+
import multiprocessing
56
import os
67
from contextlib import contextmanager
78
from functools import partial
@@ -96,22 +97,8 @@ def temp_shared_array(shape, dtype: NP_DTYPE = np.float32) -> np.ndarray:
9697
pass
9798

9899

99-
def multiprocessing_available():
100-
try:
101-
# ignore error about unused import
102-
import multiprocessing # noqa: F401
103-
return multiprocessing
104-
except ImportError:
105-
return False
106-
107-
108100
def get_cores():
109-
mp = multiprocessing_available()
110-
# get max cores on the system as default
111-
if not mp:
112-
return 1
113-
else:
114-
return mp.cpu_count()
101+
return multiprocessing.cpu_count()
115102

116103

117104
def generate_indices(num_images):

mantidimaging/helper.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -53,17 +53,3 @@ def check_data_stack(data, expected_dims=3, expected_class=Images):
5353

5454
if expected_dims != data.data.ndim:
5555
raise ValueError("Invalid data format. It does not have 3 dimensions. " "Shape: {0}".format(data.data.shape))
56-
57-
58-
def run_import_checks(config):
59-
"""
60-
Run the import checks to notify the user which features are available in
61-
the execution.
62-
"""
63-
from mantidimaging.core.parallel import utility as pu
64-
65-
log = logging.getLogger(__name__)
66-
if not pu.multiprocessing_available():
67-
log.info("Multiprocessing not available.")
68-
else:
69-
log.info("Running process on {0} cores.".format(config.func.cores))

mantidimaging/test_helpers/unit_test_helper.py

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,15 @@ def generate_images(shape=g_shape, dtype=np.float32) -> Images:
4040
return _set_random_data(d, shape)
4141

4242

43+
def generate_images_for_parallel(shape=(15, 8, 10), dtype=np.float32) -> Images:
44+
"""
45+
Doesn't do anything special, just makes a number of images big enough to be
46+
ran in parallel from the logic of multiprocessing_necessary
47+
"""
48+
d = pu.create_array(shape, dtype)
49+
return _set_random_data(d, shape)
50+
51+
4352
def _set_random_data(data, shape):
4453
n = np.random.rand(*shape)
4554
# move the data in the shared array
@@ -90,31 +99,6 @@ def vsdebug():
9099
ptvsd.wait_for_attach()
91100

92101

93-
def switch_mp_off():
94-
"""
95-
This function does very bad things that should never be replicated.
96-
But it's a unit test so it's fine.
97-
"""
98-
# backup function so we can restore it
99-
global backup_mp_avail
100-
backup_mp_avail = pu.multiprocessing_available
101-
102-
def simple_return_false():
103-
return False
104-
105-
# do bad things, swap out the function to one that returns false
106-
pu.multiprocessing_available = simple_return_false
107-
108-
109-
def switch_mp_on():
110-
"""
111-
This function does very bad things that should never be replicated.
112-
But it's a unit test so it's fine.
113-
"""
114-
# restore the original backed up function from switch_mp_off
115-
pu.multiprocessing_available = backup_mp_avail
116-
117-
118102
def assert_files_exist(cls, base_name, file_extension, file_extension_separator='.', single_file=True, num_images=1):
119103
"""
120104
Asserts that a file exists.

0 commit comments

Comments
 (0)