Skip to content

Commit e0a5808

Browse files
authored
Merge pull request #679 from DiamondLightSource/snapshot_saver
Snapshot saver feature to help with debugging
2 parents e5fb50d + 20d4713 commit e0a5808

File tree

6 files changed

+155
-6
lines changed

6 files changed

+155
-6
lines changed

docs/source/howto/run_httomo.rst

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,14 +235,15 @@ directory created by HTTomo would be
235235
Options/flags
236236
#############
237237

238-
The :code:`run` command has 16 options/flags:
238+
The :code:`run` command has 17 options/flags:
239239

240240
- :code:`--output-folder-name`
241241
- :code:`--save-all`
242242
- :code:`--gpu-id`
243243
- :code:`--reslice-dir`
244244
- :code:`--max-cpu-slices`
245245
- :code:`--max-memory`
246+
- :code:`--save-snapshots`
246247
- :code:`--monitor`
247248
- :code:`--monitor-output`
248249
- :code:`--intermediate-format`
@@ -364,6 +365,22 @@ The :code:`--max-memory` flag is for telling HTTomo how much RAM the machine has
364365
so then it can switch to using a file during execution of the pipeline if
365366
necessary.
366367

368+
:code:`--save-snapshots`
369+
~~~~~~~~~~~~~~~~~~~~~~~~
370+
371+
When this flag is enabled, the pipeline saves image snapshots at specific execution points.
372+
These snapshots are captured during selected methods - typically when a section boundary
373+
is reached and data is transferred to the CPU, at which point a slice of the data is saved for
374+
inspection.
375+
376+
This feature is particularly useful for complex pipelines (e.g. 360-degree scans with stitching and phase contrast),
377+
where intermediate processing steps involved in reconstruction may unintentionally alter
378+
the data. By reviewing these snapshot images (JPEGs), users can more easily pinpoint
379+
where issues are introduced in the pipeline.
380+
381+
Enabling snapshots incurs almost no additional computational cost, unlike the :code:`--save-all`
382+
flag, which requires saving the entire dataset into a file for each method.
383+
367384
:code:`--monitor`
368385
~~~~~~~~~~~~~~~~~
369386

httomo/cli.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,11 @@ def check(pipeline: Union[Path, str], in_data_file: Optional[Path] = None):
136136
default="0",
137137
help="Limit the amount of memory used by the pipeline to the given memory (supports strings like 3.2G or bytes)",
138138
)
139+
@click.option(
140+
"--save-snapshots",
141+
is_flag=True,
142+
help="Save intermediate images (snapshots) from some methods in the pipeline.",
143+
)
139144
@click.option(
140145
"--monitor",
141146
type=click.STRING,
@@ -211,6 +216,7 @@ def run(
211216
reslice_dir: Union[Path, None],
212217
max_cpu_slices: int,
213218
max_memory: str,
219+
save_snapshots: bool,
214220
monitor: List[str],
215221
pipeline_format: str,
216222
monitor_output: TextIO,
@@ -264,6 +270,7 @@ def run(
264270
monitor,
265271
monitor_output,
266272
reslice_dir,
273+
save_snapshots,
267274
)
268275
else:
269276
execute_sweep_run(pipeline, global_comm)
@@ -397,6 +404,7 @@ def execute_high_throughput_run(
397404
monitor: List[str],
398405
monitor_output: TextIO,
399406
reslice_dir: Union[Path, None],
407+
save_snapshots: bool,
400408
) -> None:
401409
# we use half the memory for blocks since we typically have inputs/output
402410
memory_limit = transform_limit_str_to_bytes(max_memory) // 2
@@ -415,6 +423,7 @@ def execute_high_throughput_run(
415423
global_comm,
416424
monitor=mon,
417425
memory_limit_bytes=memory_limit,
426+
save_snapshots=save_snapshots,
418427
)
419428
runner.execute()
420429
if mon is not None:

httomo/method_wrappers/images.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from mpi4py.MPI import Comm
1010

11-
1211
import os
1312
from typing import Dict, Optional
1413

httomo/runner/task_runner.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
DummySink,
2020
ReadableDataSetSink,
2121
)
22+
from httomo.utils import save_2d_snapshot
2223
from httomo.runner.gpu_utils import get_available_gpu_memory, gpumem_cleanup
2324
from httomo.runner.monitoring_interface import MonitoringInterface
2425
from httomo.runner.pipeline import Pipeline
@@ -49,11 +50,13 @@ def __init__(
4950
comm: MPI.Comm,
5051
memory_limit_bytes: int = 0,
5152
monitor: Optional[MonitoringInterface] = None,
53+
save_snapshots: bool = False,
5254
):
5355
self.pipeline = pipeline
5456
self.reslice_dir = reslice_dir
5557
self.comm = comm
5658
self.monitor = monitor
59+
self.save_snapshots = save_snapshots
5760

5861
self.side_outputs: Dict[str, Any] = dict()
5962
self.source: Optional[DataSetSource] = None
@@ -145,6 +148,7 @@ def _execute_section(self, section: Section, section_index: int = 0):
145148

146149
splitter = BlockSplitter(self.source, section.max_slices)
147150
no_of_blocks = len(splitter)
151+
section_length = len(section)
148152

149153
# Redirect tqdm progress bar output to /dev/null, and instead manually write block
150154
# processing progress to logfile within loop
@@ -160,8 +164,8 @@ def _execute_section(self, section: Section, section_index: int = 0):
160164
if self.monitor is not None:
161165
self.monitor.report_source_block(
162166
f"sec_{section_index}",
163-
section.methods[0].task_id if len(section) > 0 else "",
164-
_get_slicing_dim(section.pattern) - 1,
167+
section.methods[0].task_id if section_length > 0 else "",
168+
slicing_dim_section,
165169
block.shape,
166170
block.chunk_index,
167171
block.global_index,
@@ -170,6 +174,23 @@ def _execute_section(self, section: Section, section_index: int = 0):
170174

171175
log_once(f" {str(progress)}", level=logging.INFO)
172176
block = self._execute_section_block(section, block)
177+
if (
178+
self.save_snapshots
179+
and self.comm.rank == self.comm.size // 2
180+
and idx == no_of_blocks // 2
181+
):
182+
# save a 2D state-snapshot taken from the middle data block of the middle chunk
183+
snapshot_slicer = [slice(None)] * block.data.ndim
184+
snapshot_slicer[slicing_dim_section] = (
185+
np.shape(block.data)[slicing_dim_section] // 2
186+
)
187+
snapshot_slice = block.data[tuple(snapshot_slicer)]
188+
method_to_snapshot_name = self._get_methods_name_for_snapshot(section)
189+
save_2d_snapshot(
190+
snapshot_slice,
191+
methods_name=method_to_snapshot_name,
192+
section_index=section_index,
193+
)
173194
log_rank(
174195
f" Finished processing block {idx + 1} of {no_of_blocks}",
175196
comm=self.comm,
@@ -181,7 +202,7 @@ def _execute_section(self, section: Section, section_index: int = 0):
181202
if self.monitor is not None:
182203
self.monitor.report_sink_block(
183204
f"sec_{section_index}",
184-
section.methods[-1].task_id if len(section) > 0 else "",
205+
section.methods[-1].task_id if section_length > 0 else "",
185206
_get_slicing_dim(section.pattern) - 1,
186207
block.shape,
187208
block.chunk_index,
@@ -280,6 +301,21 @@ def _execute_section_block(
280301
if_previous_block_is_on_gpu = if_current_block_is_on_gpu
281302
return block
282303

304+
def _get_methods_name_for_snapshot(self, section: Section) -> str:
    """
    Choose the method name used to label the snapshot taken for ``section``.

    The wrappers in ``section.methods`` are scanned from last to first and
    the ``method_name`` of the first one that is not in the exclusion set
    below is returned.

    :param section: the section whose ``methods`` are inspected
    :return: the name of the last method in the section that is not excluded
    :raises ValueError: if the section has no methods, or all are excluded
    """
    # Names skipped when labelling a snapshot; a frozenset gives constant-time
    # membership tests.
    irrelevant_method_names_snapshots = frozenset(
        (
            "data_checker",
            "calculate_stats",
            "find_center_360",
            "find_center_pc",
            "find_center_vo",
            "save_intermediate_data",
        )
    )
    # reversed() already yields the wrappers lazily; no intermediate list needed
    for wrapper in reversed(section.methods):
        if wrapper.method_name not in irrelevant_method_names_snapshots:
            return wrapper.method_name
    raise ValueError("Unable to find method name in section for snapshot saving")
318+
283319
def _log_pipeline(self, msg: Any, level: int = logging.INFO):
284320
log_once(msg, level=level)
285321

httomo/utils.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
import sys
22
import logging
3-
from enum import Enum
43
from time import perf_counter_ns
54
from traceback import format_tb
65
from typing import Any, Callable, Dict, List, Literal, Tuple
76

87
from loguru import logger
98
from mpi4py import MPI
109
import numpy as np
10+
from PIL import Image
11+
import os
12+
import httomo.globals
13+
from pathlib import Path
14+
1115

1216
from httomo_backends.methods_database.query import Pattern
1317

@@ -25,6 +29,33 @@
2529
import numpy as xp
2630

2731

32+
def save_2d_snapshot(
    data_slice: xp.ndarray, methods_name: str, section_index: int
) -> None:
    """
    Save a 2D slice as an 8-bit greyscale JPEG to help with debugging.

    The image is written into the ``pipeline_stages_snapshots`` folder inside
    ``httomo.globals.run_out_dir`` (the folder is created if it is missing).
    NaN and infinite values are replaced with zero, intensities are clipped
    to the 1st-99th percentile range and then rescaled to 0-255.

    :param data_slice: a 2D array to save as a jpeg image; it is left unmodified
    :type data_slice: xp.ndarray
    :param methods_name: the name of the image to be saved (e.g. method's name)
    :type methods_name: str
    :param section_index: the index of the section
    :type section_index: int
    """
    output_dir_snapshots = (
        Path(httomo.globals.run_out_dir) / "pipeline_stages_snapshots"
    )
    output_dir_snapshots.mkdir(parents=True, exist_ok=True)

    # Work on a copy: the caller may pass a view into the data block being
    # processed, and an in-place nan_to_num would alter the pipeline data.
    image = np.nan_to_num(data_slice, copy=True, nan=0.0, posinf=0, neginf=0)
    vmin, vmax = np.percentile(image, [1, 99])
    image = np.clip(image, vmin, vmax)

    intensity_range = vmax - vmin
    if intensity_range > 0:
        image = (image - vmin) / intensity_range
        image = (image * 255).astype(np.uint8)
    else:
        # Constant slice: avoid a 0/0 division (NaN followed by an undefined
        # cast to uint8) and store a uniformly black image instead.
        image = np.zeros(np.shape(image), dtype=np.uint8)

    # Same name as before: a literal "0", the section index, the method name.
    filename = f"0{section_index}{methods_name}.jpeg"
    Image.fromarray(image, mode="L").save(
        output_dir_snapshots / filename, quality=95
    )
58+
2859
def log_once(output: Any, level: int = logging.INFO) -> None:
2960
"""
3061
Log output to console and log file if the process' global rank is zero.

tests/runner/test_task_runner.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,63 @@ def test_warns_with_multiple_reslices(
457457
assert "Data saving or/and reslicing operation will be performed 4 times" in args[0]
458458

459459

460+
def test_get_method_names_for_snapshot_saver(
    mocker: MockerFixture,
    dummy_block: DataSetBlock,
    tmp_path: PathLike,
):
    """
    Check the snapshot label chosen for every section of a mixed pipeline.

    The pipeline alternates projection and sinogram methods and includes
    names that ``_get_methods_name_for_snapshot`` ignores (``data_checker``,
    ``find_center_pc``, ``calculate_stats``). For each section returned by
    ``_sectionize`` the reported name must match the expected one.
    """
    loader = make_test_loader(mocker, block=dummy_block, pattern=Pattern.projection)

    # (method name, pattern) pairs, in pipeline order
    method_specs = [
        ("m1", Pattern.projection),
        ("m2_rec", Pattern.projection),
        ("data_checker", Pattern.projection),
        ("find_center_pc", Pattern.sinogram),
        ("m5_rec", Pattern.sinogram),
        ("data_checker", Pattern.sinogram),
        ("m7_rec", Pattern.projection),
        ("data_checker", Pattern.projection),
        ("m9_rec", Pattern.sinogram),
        ("data_checker", Pattern.sinogram),
        ("calculate_stats", Pattern.all),
    ]
    pipeline_methods = [
        make_test_method(mocker, method_name=name, pattern=pattern)
        for name, pattern in method_specs
    ]

    pipeline = Pipeline(loader=loader, methods=pipeline_methods)
    runner = TaskRunner(pipeline, reslice_dir=tmp_path, comm=MPI.COMM_WORLD)
    sections = runner._sectionize()

    expected_names = ["m2_rec", "m5_rec", "m7_rec", "m9_rec"]
    for position, section in enumerate(sections):
        assert (
            runner._get_methods_name_for_snapshot(section)
            == expected_names[position]
        )
515+
516+
460517
def test_warns_with_multiple_stores_from_side_outputs(
461518
mocker: MockerFixture,
462519
dummy_block: DataSetBlock,

0 commit comments

Comments
 (0)