Skip to content

Commit 3bbf042

Browse files
author
Dimitar Tasev
committed
Refactors flat fielding
1 parent 5ea736b commit 3bbf042

File tree

2 files changed

+21
-16
lines changed

2 files changed

+21
-16
lines changed

mantidimaging/core/operations/flat_fielding/flat_fielding.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010
from mantidimaging import helper as h
1111
from mantidimaging.core.data import Images
1212
from mantidimaging.core.operations.base_filter import BaseFilter, FilterGroup
13-
from mantidimaging.core.parallel import two_shared_mem as ptsm
14-
from mantidimaging.core.parallel import utility as pu
13+
from mantidimaging.core.parallel import utility as pu, shared as ps
1514
from mantidimaging.core.utility.progress_reporting import Progress
1615
from mantidimaging.gui.utility.qt_helpers import Type
1716
from mantidimaging.gui.widgets.stack_selector import StackSelectorWidgetView
@@ -58,7 +57,7 @@ class FlatFieldFilter(BaseFilter):
5857
filter_name = 'Flat-fielding'
5958

6059
@staticmethod
61-
def filter_func(data: Images,
60+
def filter_func(images: Images,
6261
flat_before: Images = None,
6362
flat_after: Images = None,
6463
dark_before: Images = None,
@@ -80,7 +79,7 @@ def filter_func(data: Images,
8079
:param chunksize: The number of chunks that each worker will receive.
8180
:return: Filtered data (stack of images)
8281
"""
83-
h.check_data_stack(data)
82+
h.check_data_stack(images)
8483

8584
if selected_flat_fielding is not None:
8685
if selected_flat_fielding == "Both, concatenated" and flat_after is not None and flat_before is not None \
@@ -101,19 +100,19 @@ def filter_func(data: Images,
101100
if 2 != flat_avg.ndim or 2 != dark_avg.ndim:
102101
raise ValueError(
103102
f"Incorrect shape of the flat image ({flat_avg.shape}) or dark image ({dark_avg.shape}) \
104-
which should match the shape of the sample images ({data.data.shape})")
103+
which should match the shape of the sample images ({images.data.shape})")
105104

106-
if not data.data.shape[1:] == flat_avg.shape == dark_avg.shape:
107-
raise ValueError(f"Not all images are the expected shape: {data.data.shape[1:]}, instead "
105+
if not images.data.shape[1:] == flat_avg.shape == dark_avg.shape:
106+
raise ValueError(f"Not all images are the expected shape: {images.data.shape[1:]}, instead "
108107
f"flat had shape: {flat_avg.shape}, and dark had shape: {dark_avg.shape}")
109108

110109
progress = Progress.ensure_instance(progress,
111-
num_steps=data.data.shape[0],
110+
num_steps=images.data.shape[0],
112111
task_name='Background Correction')
113-
_execute(data.data, flat_avg, dark_avg, cores, chunksize, progress)
112+
_execute(images.data, flat_avg, dark_avg, cores, chunksize, progress)
114113

115-
h.check_data_stack(data)
116-
return data
114+
h.check_data_stack(images)
115+
return images
117116

118117
@staticmethod
119118
def register_gui(form, on_change, view: FiltersWindowView) -> Dict[str, Any]:
@@ -260,7 +259,7 @@ def _subtract(data, dark=None):
260259
np.subtract(data, dark, out=data)
261260

262261

263-
def _execute(data, flat=None, dark=None, cores=None, chunksize=None, progress=None):
262+
def _execute(data: np.ndarray, flat=None, dark=None, cores=None, chunksize=None, progress=None):
264263
"""A benchmark justifying the current implementation, performed on
265264
500x2048x2048 images.
266265
@@ -289,11 +288,13 @@ def _execute(data, flat=None, dark=None, cores=None, chunksize=None, progress=No
289288
norm_divide[norm_divide == 0] = MINIMUM_PIXEL_VALUE
290289

291290
# subtract the dark from all images
292-
f = ptsm.create_partial(_subtract, fwd_function=ptsm.inplace_second_2d)
293-
data, dark = ptsm.execute(data, dark, f, cores, chunksize, progress=progress)
291+
do_subtract = ps.create_partial(_subtract, fwd_function=ps.inplace_second_2d)
292+
ps.shared_list = [data, dark]
293+
ps.execute(do_subtract, data.shape[0], progress, cores=cores)
294294

295295
# divide the data by (flat - dark)
296-
f = ptsm.create_partial(_divide, fwd_function=ptsm.inplace_second_2d)
297-
data, norm_divide = ptsm.execute(data, norm_divide, f, cores, chunksize, progress=progress)
296+
do_divide = ps.create_partial(_divide, fwd_function=ps.inplace_second_2d)
297+
ps.shared_list = [data, norm_divide]
298+
ps.execute(do_divide, data.shape[0], progress, cores=cores)
298299

299300
return data

mantidimaging/core/parallel/shared.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ def return_to_self1(func, i, **kwargs):
2121
shared_list[0][i] = func(shared_list[0][i], **kwargs)
2222

2323

24+
def inplace_second_2d(func, i, **kwargs):
25+
func(shared_list[0][i], shared_list[1], **kwargs)
26+
27+
2428
def create_partial(func, fwd_function, **kwargs):
2529
"""
2630
Create a partial using functools.partial, to forward the kwargs to the

0 commit comments

Comments
 (0)