Skip to content

Commit 1413eb9

Browse files
authored
Merge pull request #131 from BrainLesion/92-unite-registration-steps-into-one-transformation
92 unite registration steps into one transformation
2 parents 960edb3 + eddaca1 commit 1413eb9

File tree

12 files changed

+651
-59
lines changed

12 files changed

+651
-59
lines changed
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
from pathlib import Path
2+
from typing import List, Optional, Union
3+
4+
from brainles_preprocessing.constants import Atlas, PreprocessorSteps
5+
from brainles_preprocessing.defacing import Defacer, QuickshearDefacer
6+
from brainles_preprocessing.modality import CenterModality, Modality
7+
from brainles_preprocessing.registration import ANTsRegistrator
8+
from brainles_preprocessing.registration.registrator import Registrator
9+
from brainles_preprocessing.utils.logging_utils import LoggingManager
10+
from brainles_preprocessing.utils.zenodo import verify_or_download_atlases
11+
12+
logging_man = LoggingManager(name=__name__)
13+
logger = logging_man.get_logger()
14+
15+
16+
class BackToNativeSpace:
17+
18+
def __init__(
19+
self,
20+
transformations_dir: Union[str, Path],
21+
registrator: Optional[Registrator] = None,
22+
):
23+
24+
self.transformations_dir = Path(transformations_dir)
25+
26+
if registrator is None:
27+
logger.warning(
28+
"No registrator provided, using default ANTsRegistrator for registration."
29+
)
30+
self.registrator: Registrator = registrator or ANTsRegistrator()
31+
32+
def transform(
33+
self,
34+
target_modality_name: str,
35+
target_modality_img: Union[str, Path],
36+
moving_image: Union[str, Path],
37+
output_img_path: Union[str, Path],
38+
log_file_path: Union[str, Path],
39+
interpolator: Optional[str] = None,
40+
):
41+
"""
42+
Apply inverse transformation to a moving image to align it with a target modality.
43+
44+
Args:
45+
target_modality_name (str): Name of the target modality. Must match the name used to create the transformations.
46+
target_modality_img (Union[str, Path]): Path to the target modality image.
47+
moving_image (Union[str, Path]): Path to the moving image. E.g., this could be a segmentation in atlas space.
48+
output_img_path (Union[str, Path]): Path where the transformed image will be saved.
49+
log_file_path (Union[str, Path]): Path to the log file where transformation details will be written.
50+
interpolator (Optional[str]): Interpolation method used during transformation.
51+
Available options depend on the chosen registrator:
52+
53+
- **ANTsRegistrator**:
54+
- "linear" (default)
55+
- "nearestNeighbor"
56+
- "multiLabel" (deprecated, prefer "genericLabel")
57+
- "gaussian"
58+
- "bSpline"
59+
- "cosineWindowedSinc"
60+
- "welchWindowedSinc"
61+
- "hammingWindowedSinc"
62+
- "lanczosWindowedSinc"
63+
- "genericLabel" (recommended for label images)
64+
65+
- **NiftyReg**:
66+
- "0": nearest neighbor
67+
- "1": linear (default)
68+
- "3": cubic spline
69+
- "4": sinc
70+
71+
Raises:
72+
AssertionError: If the transformations directory for the given modality does not exist.
73+
"""
74+
logger.info(
75+
f"Applying inverse transformation for {target_modality_name} using {self.registrator.__class__.__name__}."
76+
)
77+
78+
# assert modality name eixsts in transformations_dir
79+
modality_transformations_dir = (
80+
self.transformations_dir / f"{target_modality_name}"
81+
)
82+
83+
assert (
84+
modality_transformations_dir.exists()
85+
), f"Transformations directory for {target_modality_name} does not exist: {modality_transformations_dir}"
86+
87+
transforms = list(modality_transformations_dir.iterdir())
88+
transforms.sort() # sort by name to get order for forward transform
89+
transforms = transforms[::-1] # inverse order for inverse transform
90+
91+
kwargs = {
92+
"fixed_image_path": target_modality_img,
93+
"moving_image_path": moving_image,
94+
"transformed_image_path": output_img_path,
95+
"matrix_path": transforms,
96+
"log_file_path": str(log_file_path),
97+
}
98+
if interpolator is not None:
99+
kwargs["interpolator"] = interpolator
100+
self.registrator.inverse_transform(**kwargs)

brainles_preprocessing/modality.py

Lines changed: 71 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,18 @@
22
import shutil
33
import warnings
44
from pathlib import Path
5-
from typing import Optional, Union
5+
from typing import Dict, Optional, Union
66

77
from auxiliary.nifti.io import read_nifti, write_nifti
8+
89
from brainles_preprocessing.brain_extraction.brain_extractor import BrainExtractor
910
from brainles_preprocessing.constants import PreprocessorSteps
1011
from brainles_preprocessing.defacing import Defacer, QuickshearDefacer
1112
from brainles_preprocessing.normalization.normalizer_base import Normalizer
13+
from brainles_preprocessing.registration import ( # TODO: this will throw warnings if ANTs or NiftyReg are not installed, not ideal
14+
ANTsRegistrator,
15+
NiftyRegRegistrator,
16+
)
1217
from brainles_preprocessing.registration.registrator import Registrator
1318

1419
logger = logging.getLogger(__name__)
@@ -42,6 +47,7 @@ class Modality:
4247
normalized_defaced_output_path (str or Path, optional): Path to save the normalized defaced modality data. Requires a normalizer.
4348
bet (bool): Indicates whether brain extraction is enabled.
4449
atlas_correction (bool): Indicates whether atlas correction should be performed.
50+
coregistration_transform_path (str or None): Path to the coregistration transformation matrix, will be set after coregistration.
4551
4652
Example:
4753
>>> t1_modality = Modality(
@@ -73,6 +79,7 @@ def __init__(
7379
self.current = self.input_path
7480
self.normalizer = normalizer
7581
self.atlas_correction = atlas_correction
82+
self.transformation_paths: Dict[PreprocessorSteps, Path | None] = {}
7683

7784
# Check that atleast one output is generated
7885
if not any(
@@ -128,6 +135,7 @@ def __init__(
128135
self.normalized_defaced_output_path = None
129136

130137
self.steps = {k: None for k in PreprocessorSteps}
138+
self.steps[PreprocessorSteps.INPUT] = self.input_path
131139

132140
@property
133141
def bet(self) -> bool:
@@ -197,6 +205,26 @@ def normalize(
197205
else:
198206
logger.info("No normalizer specified; skipping normalization.")
199207

208+
def _find_transformation_matrix(
209+
self, transform_incomplete_path: Path
210+
) -> Optional[Path]:
211+
possible_Files = list(
212+
transform_incomplete_path.parent.glob(f"{transform_incomplete_path.stem}.*")
213+
)
214+
if len(possible_Files) == 0:
215+
logger.warning(
216+
f"No transformation matrix found for {transform_incomplete_path}. "
217+
"Returning None."
218+
)
219+
return None
220+
elif len(possible_Files) > 1:
221+
# TODO: Handle this case more gracefully, e.g., try to find proper extension based on the registrator
222+
logger.warning(
223+
f"Multiple transformation matrices found for {transform_incomplete_path}. "
224+
"Returning the first one."
225+
)
226+
return possible_Files[0]
227+
200228
def register(
201229
self,
202230
registrator: Registrator,
@@ -213,6 +241,7 @@ def register(
213241
fixed_image_path (str or Path): Path to the fixed image.
214242
registration_dir (str or Path): Directory to store registration results.
215243
moving_image_name (str): Name of the moving image.
244+
step (PreprocessorSteps): The current preprocessing step.
216245
217246
Returns:
218247
Path: Path to the registration matrix.
@@ -224,7 +253,7 @@ def register(
224253
registered_log = registration_dir / f"{moving_image_name}.log"
225254

226255
# Note, add file ending depending on registration backend!
227-
registered_matrix = registration_dir / f"{moving_image_name}"
256+
registered_matrix = registration_dir / f"M_{moving_image_name}"
228257

229258
registrator.register(
230259
fixed_image_path=fixed_image_path,
@@ -235,6 +264,11 @@ def register(
235264
)
236265
self.current = registered
237266
self.steps[step] = registered
267+
268+
self.transformation_paths[step] = self._find_transformation_matrix(
269+
transform_incomplete_path=registered_matrix
270+
)
271+
238272
return registered_matrix
239273

240274
def apply_bet_mask(
@@ -320,7 +354,7 @@ def transform(
320354
registration_dir_path (str or Path): Directory to store transformation results.
321355
moving_image_name (str): Name of the moving image.
322356
transformation_matrix_path (str or Path): Path to the transformation matrix.
323-
357+
step (PreprocessorSteps): The current preprocessing step.
324358
Returns:
325359
None
326360
"""
@@ -331,15 +365,42 @@ def transform(
331365
transformed = registration_dir_path / f"{moving_image_name}.nii.gz"
332366
transformed_log = registration_dir_path / f"{moving_image_name}.log"
333367

334-
registrator.transform(
335-
fixed_image_path=fixed_image_path,
336-
moving_image_path=self.current,
337-
transformed_image_path=transformed,
338-
matrix_path=transformation_matrix_path,
339-
log_file_path=transformed_log,
340-
)
368+
if (
369+
isinstance(registrator, (ANTsRegistrator, NiftyRegRegistrator))
370+
and step == PreprocessorSteps.ATLAS_REGISTERED
371+
):
372+
# we test uniting transforms for these registrators
373+
assert (
374+
self.transformation_paths.get(PreprocessorSteps.COREGISTERED, None)
375+
is not None
376+
), "Coregistration must be performed before applying atlas registration."
377+
378+
registrator.transform(
379+
fixed_image_path=fixed_image_path,
380+
moving_image_path=self.steps[PreprocessorSteps.INPUT],
381+
transformed_image_path=transformed,
382+
matrix_path=[
383+
self.transformation_paths[
384+
PreprocessorSteps.COREGISTERED
385+
], # coregistration matrix
386+
transformation_matrix_path, # atlas registration matrix
387+
],
388+
log_file_path=transformed_log,
389+
)
390+
else:
391+
registrator.transform(
392+
fixed_image_path=fixed_image_path,
393+
moving_image_path=self.current,
394+
transformed_image_path=transformed,
395+
matrix_path=transformation_matrix_path,
396+
log_file_path=transformed_log,
397+
)
398+
341399
self.current = transformed
342400
self.steps[step] = transformed
401+
self.transformation_paths[step] = self._find_transformation_matrix(
402+
transform_incomplete_path=transformation_matrix_path
403+
)
343404

344405
def extract_brain_region(
345406
self,

brainles_preprocessing/preprocessor.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ def run(
164164
save_dir_atlas_correction: Optional[Union[str, Path]] = None,
165165
save_dir_brain_extraction: Optional[Union[str, Path]] = None,
166166
save_dir_defacing: Optional[Union[str, Path]] = None,
167+
save_dir_transformations: Optional[Union[str, Path]] = None,
167168
log_file: Optional[Union[str, Path]] = None,
168169
):
169170
"""
@@ -176,6 +177,7 @@ def run(
176177
save_dir_atlas_correction (str or Path, optional): Directory path to save intermediate atlas correction results.
177178
save_dir_brain_extraction (str or Path, optional): Directory path to save intermediate brain extraction results.
178179
save_dir_defacing (str or Path, optional): Directory path to save intermediate defacing results.
180+
save_dir_transformations (str or Path, optional): Directory path to save transformation matrices. Defaults to None.
179181
log_file (str or Path, optional): Path to save the log file. Defaults to a timestamped file in the current directory.
180182
181183
This method orchestrates the entire preprocessing workflow by sequentially performing:
@@ -250,6 +252,28 @@ def run(
250252
save_dir_defacing=save_dir_defacing,
251253
)
252254

255+
# move to separate method
256+
if save_dir_transformations:
257+
save_dir_transformations = Path(save_dir_transformations)
258+
259+
# Save transformation matrices
260+
logger.info(f"Saving transformation matrices to {save_dir_transformations}")
261+
for modality in self.all_modalities:
262+
263+
modality_transformations_dir = (
264+
save_dir_transformations / modality.modality_name
265+
)
266+
modality_transformations_dir.mkdir(exist_ok=True, parents=True)
267+
for step, path in modality.transformation_paths.items():
268+
if path is not None:
269+
shutil.copyfile(
270+
src=str(path.absolute()),
271+
dst=str(
272+
modality_transformations_dir
273+
/ f"{step.value}_{path.name}"
274+
),
275+
)
276+
253277
# End
254278
logger.info(f"{' Preprocessing complete ':=^80}")
255279

0 commit comments

Comments
 (0)