1919 DummySink ,
2020 ReadableDataSetSink ,
2121)
22+ from httomo .utils import save_2d_snapshot
2223from httomo .runner .gpu_utils import get_available_gpu_memory , gpumem_cleanup
2324from httomo .runner .monitoring_interface import MonitoringInterface
2425from httomo .runner .pipeline import Pipeline
@@ -49,11 +50,13 @@ def __init__(
4950 comm : MPI .Comm ,
5051 memory_limit_bytes : int = 0 ,
5152 monitor : Optional [MonitoringInterface ] = None ,
53+ save_snapshots : bool = False ,
5254 ):
5355 self .pipeline = pipeline
5456 self .reslice_dir = reslice_dir
5557 self .comm = comm
5658 self .monitor = monitor
59+ self .save_snapshots = save_snapshots
5760
5861 self .side_outputs : Dict [str , Any ] = dict ()
5962 self .source : Optional [DataSetSource ] = None
@@ -145,6 +148,7 @@ def _execute_section(self, section: Section, section_index: int = 0):
145148
146149 splitter = BlockSplitter (self .source , section .max_slices )
147150 no_of_blocks = len (splitter )
151+ section_length = len (section )
148152
149153 # Redirect tqdm progress bar output to /dev/null, and instead manually write block
150154 # processing progress to logfile within loop
@@ -160,8 +164,8 @@ def _execute_section(self, section: Section, section_index: int = 0):
160164 if self .monitor is not None :
161165 self .monitor .report_source_block (
162166 f"sec_{ section_index } " ,
163- section .methods [0 ].task_id if len ( section ) > 0 else "" ,
164- _get_slicing_dim ( section . pattern ) - 1 ,
167+ section .methods [0 ].task_id if section_length > 0 else "" ,
168+ slicing_dim_section ,
165169 block .shape ,
166170 block .chunk_index ,
167171 block .global_index ,
@@ -170,6 +174,23 @@ def _execute_section(self, section: Section, section_index: int = 0):
170174
171175 log_once (f" { str (progress )} " , level = logging .INFO )
172176 block = self ._execute_section_block (section , block )
177+ if (
178+ self .save_snapshots
179+ and self .comm .rank == self .comm .size // 2
180+ and idx == no_of_blocks // 2
181+ ):
182+ # save the 2D state-snapshot of the mid-data block from mid-chunk
183+ snapshot_slicer = [slice (None )] * block .data .ndim
184+ snapshot_slicer [slicing_dim_section ] = (
185+ np .shape (block .data )[slicing_dim_section ] // 2
186+ )
187+ snapshot_slice = block .data [tuple (snapshot_slicer )]
188+ method_to_snapshot_name = self ._get_methods_name_for_snapshot (section )
189+ save_2d_snapshot (
190+ snapshot_slice ,
191+ methods_name = method_to_snapshot_name ,
192+ section_index = section_index ,
193+ )
173194 log_rank (
174195 f" Finished processing block { idx + 1 } of { no_of_blocks } " ,
175196 comm = self .comm ,
@@ -181,7 +202,7 @@ def _execute_section(self, section: Section, section_index: int = 0):
181202 if self .monitor is not None :
182203 self .monitor .report_sink_block (
183204 f"sec_{ section_index } " ,
184- section .methods [- 1 ].task_id if len ( section ) > 0 else "" ,
205+ section .methods [- 1 ].task_id if section_length > 0 else "" ,
185206 _get_slicing_dim (section .pattern ) - 1 ,
186207 block .shape ,
187208 block .chunk_index ,
@@ -280,6 +301,21 @@ def _execute_section_block(
280301 if_previous_block_is_on_gpu = if_current_block_is_on_gpu
281302 return block
282303
def _get_methods_name_for_snapshot(self, section: Section) -> str:
    """Choose the method name used to label a snapshot of ``section``.

    The section's methods are scanned from last to first, and the
    ``method_name`` of the first wrapper whose name is not in the skip
    list below is returned.

    Raises:
        ValueError: if no method in the section qualifies (every name is
            in the skip list, or the section has no methods).
    """
    # Method names treated as irrelevant when naming a snapshot.
    skipped_names = (
        "data_checker",
        "calculate_stats",
        "find_center_360",
        "find_center_pc",
        "find_center_vo",
        "save_intermediate_data",
    )
    # Unique sentinel so that any real method_name value is distinguishable
    # from "nothing found".
    not_found = object()
    candidates = (
        wrapper.method_name
        for wrapper in reversed(section.methods)
        if wrapper.method_name not in skipped_names
    )
    chosen = next(candidates, not_found)
    if chosen is not_found:
        raise ValueError("Unable to find method name in section for snapshot saving")
    return chosen
318+
def _log_pipeline(self, msg: Any, level: int = logging.INFO):
    """Forward ``msg`` to ``log_once`` at the given logging ``level``."""
    log_once(msg, level=level)
285321
0 commit comments