Skip to content

Commit 5ea736b

Browse files
author
Dimitar Tasev
committed
Parametrise the multiprocessing_necessary test
1 parent 937ecb1 commit 5ea736b

File tree

2 files changed

+23
-11
lines changed

2 files changed

+23
-11
lines changed

mantidimaging/core/parallel/test/utility_test.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,32 @@
11
# Copyright (C) 2020 ISIS Rutherford Appleton Laboratory UKRI
22
# SPDX - License - Identifier: GPL-3.0-or-later
33

4+
from typing import List, Tuple, Union
45
from unittest import mock
56

7+
import pytest
8+
69
from mantidimaging.core.parallel.utility import execute_impl, multiprocessing_necessary
710

811

9-
def test_correctly_chooses_parallel():
10-
# forcing 1 core should always return False
11-
assert multiprocessing_necessary((100, 10, 10), cores=1) is False
12-
# shapes less than 10 should return false
13-
assert multiprocessing_necessary((10, 10, 10), cores=12) is False
14-
assert multiprocessing_necessary(10, cores=12) is False
15-
# shapes over 10 should return True
16-
assert multiprocessing_necessary((11, 10, 10), cores=12) is True
17-
assert multiprocessing_necessary(11, cores=12) is True
12+
@pytest.mark.parametrize(
13+
'shape,cores,should_be_parallel',
14+
(
15+
[(100, 10, 10), 1, False], # forcing 1 core should always return False
16+
# shapes <= 10 should return False
17+
[(10, 10, 10), 12, False],
18+
[10, 12, False],
19+
# shapes over 10 should return True
20+
[(11, 10, 10), 12, True],
21+
[11, 12, True],
22+
# repeat from above but with list, to cover that branch of the if
23+
[[100, 10, 10], 1, False],
24+
[[10, 10, 10], 12, False],
25+
[[11, 10, 10], 12, True],
26+
))
27+
def test_correctly_chooses_parallel(shape: Union[int, List, Tuple[int, int, int]], cores: int,
28+
should_be_parallel: bool):
29+
assert multiprocessing_necessary(shape, cores) is should_be_parallel
1830

1931

2032
@mock.patch('mantidimaging.core.parallel.utility.Pool')

mantidimaging/core/parallel/utility.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# see https://github.com/mantidproject/mantidimaging/pull/762#issuecomment-741663482
1111
from multiprocessing import heap # type: ignore
1212
from multiprocessing.pool import Pool
13-
from typing import Any, Tuple, Type, Union
13+
from typing import Any, List, Tuple, Type, Union
1414

1515
import numpy as np
1616

@@ -97,7 +97,7 @@ def calculate_chunksize(cores):
9797
return 1
9898

9999

100-
def multiprocessing_necessary(shape: Union[int, Tuple[int, int, int]], cores) -> bool:
100+
def multiprocessing_necessary(shape: Union[int, Tuple[int, int, int], List], cores) -> bool:
101101
# This environment variable will be present when running PYDEVD from PyCharm
102102
# and that has the bug that multiprocessing Pools can never finish `.join()` ing
103103
# thus never actually finish their processing.

0 commit comments

Comments
 (0)