Skip to content

Commit d4ffdae

Browse files
authored
Merge pull request #668 from DiamondLightSource/remove_redundant_in_sweep
adding function to remove redundant methods in sweep pipeline
2 parents 48f570c + 1c6b192 commit d4ffdae

File tree

2 files changed

+109
-47
lines changed

2 files changed

+109
-47
lines changed

httomo/transform_layer.py

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,13 @@ def __init__(
3636
self._out_dir = out_dir if out_dir is not None else httomo.globals.run_out_dir
3737

3838
def transform(self, pipeline: Pipeline) -> Pipeline:
39+
pipeline_is_sweep = _check_if_pipeline_has_a_sweep(pipeline)
40+
3941
pipeline = self.insert_data_reducer(pipeline)
42+
if pipeline_is_sweep:
43+
pipeline = self.remove_redundant_method_in_sweep(pipeline)
4044
pipeline = self.insert_data_checker(pipeline)
4145

42-
pipeline_is_sweep = _check_if_pipeline_has_a_sweep(pipeline)
43-
4446
if pipeline_is_sweep:
4547
pipeline = self.insert_save_images_after_sweep(pipeline)
4648
else:
@@ -95,30 +97,18 @@ def insert_data_checker(self, pipeline: Pipeline) -> Pipeline:
9597
"""This will insert CPU or GPU data checker method AFTER most of the methods in the pipeline"""
9698
loader = pipeline.loader
9799
methods = []
98-
methods.append(
99-
GenericMethodWrapper(
100-
self._repo,
101-
"httomolib.misc.utils",
102-
"data_checker",
103-
comm=self._comm,
104-
save_result=False,
105-
task_id="datachecker_0",
106-
infsnans_correct=False, # the input (raw) data is 16bit
107-
zeros_warning=True, # we count the zeros if the data is too sparse
108-
data_to_method_name="Data Loader",
109-
),
110-
)
111100
for index, m in enumerate(pipeline):
112101
methods.append(m)
113102
# handling some exceptions here after which we don't need to insert the data checker
103+
exceptions_methods = [
104+
"data_reducer",
105+
"data_checker",
106+
"calculate_stats",
107+
"rescale_to_int",
108+
"save_to_images",
109+
]
114110
if (
115-
m.method_name
116-
not in [
117-
"data_reducer",
118-
"data_checker",
119-
"calculate_stats",
120-
"rescale_to_int",
121-
]
111+
m.method_name not in exceptions_methods
122112
and "rotation" not in m.module_path
123113
and index < len(pipeline._methods) - 1
124114
):
@@ -152,7 +142,6 @@ def insert_data_checker(self, pipeline: Pipeline) -> Pipeline:
152142
data_to_method_name=m.method_name,
153143
),
154144
)
155-
156145
return Pipeline(loader, methods)
157146

158147
def insert_save_images_after_sweep(self, pipeline: Pipeline) -> Pipeline:
@@ -177,3 +166,13 @@ def insert_save_images_after_sweep(self, pipeline: Pipeline) -> Pipeline:
177166
)
178167
sweep_before = True
179168
return Pipeline(loader, methods)
169+
170+
def remove_redundant_method_in_sweep(self, pipeline: Pipeline) -> Pipeline:
171+
"""Remove "redundant" methods in the sweep pipeline that were inserted by the user."""
172+
redundant_methods = ["calculate_stats", "rescale_to_int", "save_to_images"]
173+
loader = pipeline.loader
174+
methods = []
175+
for m in pipeline:
176+
if m.method_name not in redundant_methods:
177+
methods.append(m)
178+
return Pipeline(loader, methods)

tests/test_transform_layer.py

Lines changed: 87 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,70 @@ def test_insert_image_save_after_sweep(mocker: MockerFixture, tmp_path: Path):
186186
assert pipeline[3].config_params["axis"] == 1
187187

188188

189+
def test_remove_redundant_method_in_sweep(mocker: MockerFixture, tmp_path: Path):
190+
comm = MPI.COMM_SELF
191+
repo = make_mock_repo(mocker)
192+
loader = mocker.create_autospec(
193+
LoaderInterface,
194+
instance=True,
195+
)
196+
pipeline = Pipeline(
197+
loader=loader,
198+
methods=[
199+
make_test_method(
200+
mocker,
201+
method_name="remove_outlier",
202+
module_path="httomolibgpu.misc.corr",
203+
save_result=False,
204+
task_id="t1",
205+
),
206+
make_test_method(
207+
mocker,
208+
method_name="dark_flat_field_correction",
209+
module_path="httomolibgpu.prep.normalize",
210+
save_result=False,
211+
task_id="t2",
212+
),
213+
make_test_method(
214+
mocker,
215+
method_name="FBP3d_tomobar",
216+
module_path="httomolibgpu.recon.algorithm",
217+
save_result=False,
218+
sweep=True,
219+
task_id="t3",
220+
),
221+
make_test_method(
222+
mocker,
223+
method_name="calculate_stats",
224+
module_path="httomo.methods",
225+
save_result=False,
226+
task_id="t4",
227+
),
228+
make_test_method(
229+
mocker,
230+
method_name="rescale_to_int",
231+
module_path="httomolib.misc.rescale",
232+
save_result=False,
233+
task_id="t5",
234+
),
235+
make_test_method(
236+
mocker,
237+
method_name="save_to_images",
238+
module_path="httomolib.misc.images",
239+
save_result=False,
240+
task_id="t6",
241+
),
242+
],
243+
)
244+
trans = TransformLayer(comm, repo=repo, save_all=False, out_dir=tmp_path)
245+
pipeline = trans.remove_redundant_method_in_sweep(pipeline)
246+
247+
assert len(pipeline) == 3
248+
assert pipeline[0].method_name == "remove_outlier"
249+
assert pipeline[1].method_name == "dark_flat_field_correction"
250+
assert pipeline[2].method_name == "FBP3d_tomobar"
251+
252+
189253
def test_insert_data_checker(mocker: MockerFixture, tmp_path: Path):
190254
comm = MPI.COMM_SELF
191255
repo = make_mock_repo(mocker)
@@ -251,15 +315,14 @@ def test_insert_data_checker(mocker: MockerFixture, tmp_path: Path):
251315
trans = TransformLayer(comm, repo=repo, save_all=False, out_dir=tmp_path)
252316
pipeline = trans.insert_data_checker(pipeline)
253317

254-
assert len(pipeline) == 11
255-
assert pipeline[0].method_name == "data_checker"
256-
assert pipeline[2].method_name == "data_checker"
257-
assert pipeline[4].method_name == "normalize"
258-
assert pipeline[5].method_name == "data_checker"
259-
assert pipeline[6].method_name == "FBP3d_tomobar"
260-
assert pipeline[7].method_name == "data_checker"
261-
assert pipeline[9].method_name == "rescale_to_int"
262-
assert pipeline[10].method_name == "save_to_images"
318+
assert len(pipeline) == 10
319+
assert pipeline[1].method_name == "data_checker"
320+
assert pipeline[3].method_name == "normalize"
321+
assert pipeline[4].method_name == "data_checker"
322+
assert pipeline[5].method_name == "FBP3d_tomobar"
323+
assert pipeline[6].method_name == "data_checker"
324+
assert pipeline[8].method_name == "rescale_to_int"
325+
assert pipeline[9].method_name == "save_to_images"
263326

264327

265328
def test_insert_image_save_after_sweep2(mocker: MockerFixture, tmp_path: Path):
@@ -361,13 +424,13 @@ def test_insert_paganin_not_last_sweep(mocker: MockerFixture, tmp_path: Path):
361424
trans = TransformLayer(comm, repo=repo, save_all=False, out_dir=tmp_path)
362425
pipeline = trans.transform(pipeline)
363426

364-
assert len(pipeline) == 11
365-
assert pipeline[7].method_name == "save_to_images"
366-
assert pipeline[7].task_id == "saveimage_sweep_t3"
367-
assert pipeline[7].config_params["subfolder_name"] == "images_sweep_paganin_filter"
368-
assert pipeline[9].method_name == "FBP3d_tomobar"
369-
assert pipeline[10].task_id == "saveimage_sweep_t4"
370-
assert pipeline[10].config_params["subfolder_name"] == "images_sweep_FBP3d_tomobar"
427+
assert len(pipeline) == 10
428+
assert pipeline[6].method_name == "save_to_images"
429+
assert pipeline[6].task_id == "saveimage_sweep_t3"
430+
assert pipeline[6].config_params["subfolder_name"] == "images_sweep_paganin_filter"
431+
assert pipeline[8].method_name == "FBP3d_tomobar"
432+
assert pipeline[9].task_id == "saveimage_sweep_t4"
433+
assert pipeline[9].config_params["subfolder_name"] == "images_sweep_FBP3d_tomobar"
371434

372435

373436
def test_insert_paganin_is_last_sweep(mocker: MockerFixture, tmp_path: Path):
@@ -407,10 +470,10 @@ def test_insert_paganin_is_last_sweep(mocker: MockerFixture, tmp_path: Path):
407470
trans = TransformLayer(comm, repo=repo, save_all=False, out_dir=tmp_path)
408471
pipeline = trans.transform(pipeline)
409472

410-
assert len(pipeline) == 8
411-
assert pipeline[7].method_name == "save_to_images"
412-
assert pipeline[7].task_id == "saveimage_sweep_t3"
413-
assert pipeline[7].config_params["subfolder_name"] == "images_sweep_paganin_filter"
473+
assert len(pipeline) == 7
474+
assert pipeline[6].method_name == "save_to_images"
475+
assert pipeline[6].task_id == "saveimage_sweep_t3"
476+
assert pipeline[6].config_params["subfolder_name"] == "images_sweep_paganin_filter"
414477

415478

416479
def test_insert_denoise_last_after_FBP_sweep(mocker: MockerFixture, tmp_path: Path):
@@ -450,7 +513,7 @@ def test_insert_denoise_last_after_FBP_sweep(mocker: MockerFixture, tmp_path: Pa
450513
trans = TransformLayer(comm, repo=repo, save_all=False, out_dir=tmp_path)
451514
pipeline = trans.transform(pipeline)
452515

453-
assert len(pipeline) == 8
454-
assert pipeline[7].method_name == "save_to_images"
455-
assert pipeline[7].task_id == "saveimage_sweep_t3"
456-
assert pipeline[7].config_params["subfolder_name"] == "images_sweep_median_filter"
516+
assert len(pipeline) == 7
517+
assert pipeline[6].method_name == "save_to_images"
518+
assert pipeline[6].task_id == "saveimage_sweep_t3"
519+
assert pipeline[6].config_params["subfolder_name"] == "images_sweep_median_filter"

0 commit comments

Comments
 (0)