22import shutil
33import warnings
44from pathlib import Path
5- from typing import Optional , Union
5+ from typing import Dict , Optional , Union
66
77from auxiliary .nifti .io import read_nifti , write_nifti
8+
89from brainles_preprocessing .brain_extraction .brain_extractor import BrainExtractor
910from brainles_preprocessing .constants import PreprocessorSteps
1011from brainles_preprocessing .defacing import Defacer , QuickshearDefacer
1112from brainles_preprocessing .normalization .normalizer_base import Normalizer
13+ from brainles_preprocessing .registration import ( # TODO: this will throw warnings if ANTs or NiftyReg are not installed, not ideal
14+ ANTsRegistrator ,
15+ NiftyRegRegistrator ,
16+ )
1217from brainles_preprocessing .registration .registrator import Registrator
1318
1419logger = logging .getLogger (__name__ )
@@ -42,6 +47,7 @@ class Modality:
4247 normalized_defaced_output_path (str or Path, optional): Path to save the normalized defaced modality data. Requires a normalizer.
4348 bet (bool): Indicates whether brain extraction is enabled.
4449 atlas_correction (bool): Indicates whether atlas correction should be performed.
50+ coregistration_transform_path (str or None): Path to the coregistration transformation matrix, will be set after coregistration.
4551
4652 Example:
4753 >>> t1_modality = Modality(
@@ -73,6 +79,7 @@ def __init__(
7379 self .current = self .input_path
7480 self .normalizer = normalizer
7581 self .atlas_correction = atlas_correction
82+ self .transformation_paths : Dict [PreprocessorSteps , Path | None ] = {}
7683
7784 # Check that atleast one output is generated
7885 if not any (
@@ -128,6 +135,7 @@ def __init__(
128135 self .normalized_defaced_output_path = None
129136
130137 self .steps = {k : None for k in PreprocessorSteps }
138+ self .steps [PreprocessorSteps .INPUT ] = self .input_path
131139
132140 @property
133141 def bet (self ) -> bool :
@@ -197,6 +205,26 @@ def normalize(
197205 else :
198206 logger .info ("No normalizer specified; skipping normalization." )
199207
208+ def _find_transformation_matrix (
209+ self , transform_incomplete_path : Path
210+ ) -> Optional [Path ]:
211+ possible_Files = list (
212+ transform_incomplete_path .parent .glob (f"{ transform_incomplete_path .stem } .*" )
213+ )
214+ if len (possible_Files ) == 0 :
215+ logger .warning (
216+ f"No transformation matrix found for { transform_incomplete_path } . "
217+ "Returning None."
218+ )
219+ return None
220+ elif len (possible_Files ) > 1 :
221+ # TODO: Handle this case more gracefully, e.g., try to find proper extension based on the registrator
222+ logger .warning (
223+ f"Multiple transformation matrices found for { transform_incomplete_path } . "
224+ "Returning the first one."
225+ )
226+ return possible_Files [0 ]
227+
200228 def register (
201229 self ,
202230 registrator : Registrator ,
@@ -213,6 +241,7 @@ def register(
213241 fixed_image_path (str or Path): Path to the fixed image.
214242 registration_dir (str or Path): Directory to store registration results.
215243 moving_image_name (str): Name of the moving image.
244+ step (PreprocessorSteps): The current preprocessing step.
216245
217246 Returns:
218247 Path: Path to the registration matrix.
@@ -224,7 +253,7 @@ def register(
224253 registered_log = registration_dir / f"{ moving_image_name } .log"
225254
226255 # Note, add file ending depending on registration backend!
227- registered_matrix = registration_dir / f"{ moving_image_name } "
256+ registered_matrix = registration_dir / f"M_ { moving_image_name } "
228257
229258 registrator .register (
230259 fixed_image_path = fixed_image_path ,
@@ -235,6 +264,11 @@ def register(
235264 )
236265 self .current = registered
237266 self .steps [step ] = registered
267+
268+ self .transformation_paths [step ] = self ._find_transformation_matrix (
269+ transform_incomplete_path = registered_matrix
270+ )
271+
238272 return registered_matrix
239273
240274 def apply_bet_mask (
@@ -320,7 +354,7 @@ def transform(
320354 registration_dir_path (str or Path): Directory to store transformation results.
321355 moving_image_name (str): Name of the moving image.
322356 transformation_matrix_path (str or Path): Path to the transformation matrix.
323-
357+ step (PreprocessorSteps): The current preprocessing step.
324358 Returns:
325359 None
326360 """
@@ -331,15 +365,42 @@ def transform(
331365 transformed = registration_dir_path / f"{ moving_image_name } .nii.gz"
332366 transformed_log = registration_dir_path / f"{ moving_image_name } .log"
333367
334- registrator .transform (
335- fixed_image_path = fixed_image_path ,
336- moving_image_path = self .current ,
337- transformed_image_path = transformed ,
338- matrix_path = transformation_matrix_path ,
339- log_file_path = transformed_log ,
340- )
368+ if (
369+ isinstance (registrator , (ANTsRegistrator , NiftyRegRegistrator ))
370+ and step == PreprocessorSteps .ATLAS_REGISTERED
371+ ):
372+ # we test uniting transforms for these registrators
373+ assert (
374+ self .transformation_paths .get (PreprocessorSteps .COREGISTERED , None )
375+ is not None
376+ ), "Coregistration must be performed before applying atlas registration."
377+
378+ registrator .transform (
379+ fixed_image_path = fixed_image_path ,
380+ moving_image_path = self .steps [PreprocessorSteps .INPUT ],
381+ transformed_image_path = transformed ,
382+ matrix_path = [
383+ self .transformation_paths [
384+ PreprocessorSteps .COREGISTERED
385+ ], # coregistration matrix
386+ transformation_matrix_path , # atlas registration matrix
387+ ],
388+ log_file_path = transformed_log ,
389+ )
390+ else :
391+ registrator .transform (
392+ fixed_image_path = fixed_image_path ,
393+ moving_image_path = self .current ,
394+ transformed_image_path = transformed ,
395+ matrix_path = transformation_matrix_path ,
396+ log_file_path = transformed_log ,
397+ )
398+
341399 self .current = transformed
342400 self .steps [step ] = transformed
401+ self .transformation_paths [step ] = self ._find_transformation_matrix (
402+ transform_incomplete_path = transformation_matrix_path
403+ )
343404
344405 def extract_brain_region (
345406 self ,
0 commit comments