Skip to content

Commit ebf22e1

Browse files
authored
Merge pull request #139 from BrainLesion/121-feature-non-atlas-centric-preprocessing-pipeline
121 feature non atlas centric preprocessing pipeline
2 parents 157bfbe + b1ce517 commit ebf22e1

File tree

12 files changed

+898
-539
lines changed

12 files changed

+898
-539
lines changed

brainles_preprocessing/defacing/quickshear/quickshear.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from brainles_preprocessing.defacing.defacer import Defacer
88
from brainles_preprocessing.defacing.quickshear.nipy_quickshear import run_quickshear
9+
from brainles_preprocessing.constants import Atlas
910

1011

1112
class QuickshearDefacer(Defacer):
@@ -29,14 +30,23 @@ class QuickshearDefacer(Defacer):
2930
```
3031
"""
3132

32-
def __init__(self, buffer: float = 10.0):
33+
def __init__(
34+
self,
35+
buffer: float = 10.0,
36+
force_atlas_registration: bool = True,
37+
atlas_image_path: Union[str, Path, Atlas] = Atlas.SRI24,
38+
):
3339
"""Initialize Quickshear defacer
3440
3541
Args:
3642
buffer (float, optional): buffer parameter from quickshear algorithm. Defaults to 10.0.
43+
force_atlas_registration (bool, optional): If True, forces atlas registration of the BET mask before defacing to potentially boost quickshear performance. Defaults to True.
44+
atlas_image_path (Union[str, Path, Atlas], optional): Path to the atlas image or an Atlas enum value that will be used for the optional atlas registrations. Defaults to Atlas.SRI24.
3745
"""
3846
super().__init__()
3947
self.buffer = buffer
48+
self.force_atlas_registration = force_atlas_registration
49+
self.atlas_image_path = atlas_image_path
4050

4151
def deface(
4252
self,

brainles_preprocessing/modality.py

Lines changed: 89 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77
from auxiliary.io import read_image, write_image
88

99
from brainles_preprocessing.brain_extraction.brain_extractor import BrainExtractor
10-
from brainles_preprocessing.constants import PreprocessorSteps
10+
from brainles_preprocessing.constants import Atlas, PreprocessorSteps
1111
from brainles_preprocessing.defacing import Defacer, QuickshearDefacer
1212
from brainles_preprocessing.normalization.normalizer_base import Normalizer
1313
from brainles_preprocessing.registration import ( # TODO: this will throw warnings if ANTs or NiftyReg are not installed, not ideal
1414
ANTsRegistrator,
1515
NiftyRegRegistrator,
1616
)
1717
from brainles_preprocessing.registration.registrator import Registrator
18+
from brainles_preprocessing.utils.zenodo import verify_or_download_atlases
1819

1920
logger = logging.getLogger(__name__)
2021

@@ -295,7 +296,7 @@ def apply_bet_mask(
295296
if self.bet:
296297
mask_path = Path(mask_path)
297298
bet_dir = Path(bet_dir)
298-
bet_img = bet_dir / f"atlas__{self.modality_name}_bet.nii.gz"
299+
bet_img = bet_dir / f"{self.modality_name}_bet.nii.gz"
299300

300301
brain_extractor.apply_mask(
301302
input_image_path=self.current,
@@ -324,14 +325,27 @@ def apply_deface_mask(
324325
if self.requires_deface:
325326
mask_path = Path(mask_path)
326327
deface_dir = Path(deface_dir)
327-
defaced_img = deface_dir / f"atlas__{self.modality_name}_defaced.nii.gz"
328-
input_img = self.steps[
329-
(
330-
PreprocessorSteps.ATLAS_CORRECTED
331-
if self.atlas_correction
332-
else PreprocessorSteps.ATLAS_REGISTERED
328+
defaced_img = deface_dir / f"{self.modality_name}_defaced.nii.gz"
329+
330+
# For Atlas centric preprocessing, we use the atlas corrected or registered image as input
331+
# For Native space preprocessing, we use the coregistered image as input
332+
input_img = (
333+
self.steps[
334+
(
335+
PreprocessorSteps.ATLAS_CORRECTED
336+
if self.atlas_correction
337+
else PreprocessorSteps.ATLAS_REGISTERED
338+
)
339+
]
340+
or self.steps[PreprocessorSteps.COREGISTERED]
341+
)
342+
343+
if input_img is None:
344+
raise ValueError(
345+
"Input image for defacing is missing. Ensure that the required preprocessing steps "
346+
"have been performed before defacing."
333347
)
334-
]
348+
335349
defacer.apply_mask(
336350
input_image_path=input_img,
337351
mask_path=mask_path,
@@ -428,65 +442,40 @@ def extract_brain_region(
428442
"Legacy method. Please Migrate to use the CenterModality Class. Will be removed in future versions.",
429443
category=DeprecationWarning,
430444
)
431-
432445
bet_dir_path = Path(bet_dir_path)
433446
bet_log = bet_dir_path / "brain-extraction.log"
434447

435-
atlas_bet_cm = bet_dir_path / f"atlas__{self.modality_name}_bet.nii.gz"
436-
mask_path = bet_dir_path / f"atlas__{self.modality_name}_brain_mask.nii.gz"
448+
bet = bet_dir_path / f"{self.modality_name}_bet.nii.gz"
449+
mask_path = bet_dir_path / f"{self.modality_name}_brain_mask.nii.gz"
437450

438451
brain_extractor.extract(
439452
input_image_path=self.current,
440-
masked_image_path=atlas_bet_cm,
453+
masked_image_path=bet,
441454
brain_mask_path=mask_path,
442455
log_file_path=bet_log,
443456
)
444457

445458
# always temporarily store bet image for center modality, since e.g. quickshear defacing could require it
446459
# down the line even if the user does not wish to save the bet image
447-
self.steps[PreprocessorSteps.BET] = atlas_bet_cm
460+
self.steps[PreprocessorSteps.BET] = bet
448461

449462
if self.bet:
450-
self.current = atlas_bet_cm
463+
self.current = bet
451464
return mask_path
452465

453466
def deface(
454467
self,
455468
defacer,
456469
defaced_dir_path: Union[str, Path],
457-
) -> Path:
470+
registrator: Optional[Registrator] = None,
471+
) -> Path | None:
458472
"""
459473
WARNING: Legacy method. Please Migrate to use the CenterModality Class. Will be removed in future versions.
460-
461-
Deface the current modality using the specified defacer.
462-
463-
Args:
464-
defacer (Defacer): The defacer object.
465-
defaced_dir_path (str or Path): Directory to store defacing results.
466-
467-
Returns:
468-
Path: Path to the extracted brain mask.
469474
"""
470-
warnings.warn(
471-
"Legacy method. Please Migrate to use the CenterModality class. Will be removed in future versions.",
472-
category=DeprecationWarning,
475+
raise RuntimeError(
476+
"The 'deface' method has been deprecated and moved to the CenterModality class as its only supposed to be called once from the CenterModality. "
477+
"Please update your code to use the 'CenterModality.deface()' method instead."
473478
)
474-
if isinstance(defacer, QuickshearDefacer):
475-
defaced_dir_path = Path(defaced_dir_path)
476-
atlas_mask_path = (
477-
defaced_dir_path / f"atlas__{self.modality_name}_deface_mask.nii.gz"
478-
)
479-
480-
defacer.deface(
481-
mask_image_path=atlas_mask_path,
482-
input_image_path=self.steps[PreprocessorSteps.BET],
483-
)
484-
return atlas_mask_path
485-
else:
486-
logger.warning(
487-
"Defacing method not implemented yet. Skipping defacing for this modality."
488-
)
489-
return None
490479

491480
def save_current_image(
492481
self,
@@ -612,6 +601,7 @@ def extract_brain_region(
612601
bet_dir_path: Union[str, Path],
613602
) -> Path:
614603
"""
604+
615605
Extract the brain region using the specified brain extractor.
616606
617607
Args:
@@ -621,66 +611,95 @@ def extract_brain_region(
621611
Returns:
622612
Path: Path to the extracted brain mask.
623613
"""
614+
624615
bet_dir_path = Path(bet_dir_path)
625616
bet_log = bet_dir_path / "brain-extraction.log"
626617

627-
atlas_bet_cm = bet_dir_path / f"atlas__{self.modality_name}_bet.nii.gz"
628-
mask_path = bet_dir_path / f"atlas__{self.modality_name}_brain_mask.nii.gz"
618+
bet = bet_dir_path / f"{self.modality_name}_bet.nii.gz"
619+
mask_path = bet_dir_path / f"{self.modality_name}_brain_mask.nii.gz"
629620

630621
brain_extractor.extract(
631622
input_image_path=self.current,
632-
masked_image_path=atlas_bet_cm,
623+
masked_image_path=bet,
633624
brain_mask_path=mask_path,
634625
log_file_path=bet_log,
635626
)
636627

637-
if self.bet_mask_output_path:
638-
logger.debug(f"Saving bet mask to {self.bet_mask_output_path}")
639-
self.save_mask(mask_path=mask_path, output_path=self.bet_mask_output_path)
640-
641628
# always temporarily store bet image for center modality, since e.g. quickshear defacing could require it
642629
# down the line even if the user does not wish to save the bet image
643-
self.steps[PreprocessorSteps.BET] = atlas_bet_cm
630+
self.steps[PreprocessorSteps.BET] = bet
644631

645632
if self.bet:
646-
self.current = atlas_bet_cm
633+
self.current = bet
647634
return mask_path
648635

649636
def deface(
650637
self,
651-
defacer,
638+
defacer: Defacer,
652639
defaced_dir_path: Union[str, Path],
653-
) -> Path:
640+
registrator: Optional[Registrator] = None,
641+
) -> Path | None:
654642
"""
655643
Deface the current modality using the specified defacer.
656644
657645
Args:
658646
defacer (Defacer): The defacer object.
659647
defaced_dir_path (str or Path): Directory to store defacing results.
648+
registrator (Registrator, optional): The registrator object for atlas registration.
660649
661650
Returns:
662-
Path: Path to the extracted brain mask.
651+
Path | None: Path to the defacing mask if successful, None otherwise.
663652
"""
664-
665653
if isinstance(defacer, QuickshearDefacer):
666654
defaced_dir_path = Path(defaced_dir_path)
667-
atlas_mask_path = (
668-
defaced_dir_path / f"atlas__{self.modality_name}_deface_mask.nii.gz"
669-
)
655+
mask_path = defaced_dir_path / f"{self.modality_name}_deface_mask.nii.gz"
670656

671-
defacer.deface(
672-
mask_image_path=atlas_mask_path,
673-
input_image_path=self.steps[PreprocessorSteps.BET],
674-
)
657+
if self.steps.get(PreprocessorSteps.BET, None) is None:
658+
raise ValueError(
659+
"Brain extraction must be performed before defacing. "
660+
"Please run brain extraction first."
661+
)
675662

676-
if self.defacing_mask_output_path:
677-
logger.debug(f"Saving deface mask to {self.defacing_mask_output_path}")
678-
self.save_mask(
679-
mask_path=atlas_mask_path,
680-
output_path=self.defacing_mask_output_path,
663+
if defacer.force_atlas_registration and registrator is not None:
664+
logger.info("Forcing atlas registration before defacing as requested.")
665+
atlas_bet = defaced_dir_path / "atlas_bet.nii.gz"
666+
atlas_bet_M = defaced_dir_path / "M_atlas_bet"
667+
668+
# resolve atlas image path
669+
if isinstance(defacer.atlas_image_path, Atlas):
670+
atlas_folder = verify_or_download_atlases()
671+
atlas_image_path = atlas_folder / defacer.atlas_image_path.value
672+
else:
673+
atlas_image_path = Path(defacer.atlas_image_path)
674+
675+
registrator.register(
676+
fixed_image_path=atlas_image_path,
677+
moving_image_path=self.steps[PreprocessorSteps.BET],
678+
transformed_image_path=atlas_bet,
679+
matrix_path=atlas_bet_M,
680+
log_file_path=defaced_dir_path / "atlas_bet.log",
681+
)
682+
683+
deface_mask_atlas = defaced_dir_path / "deface_mask_atlas.nii.gz"
684+
defacer.deface(
685+
input_image_path=atlas_bet,
686+
mask_image_path=deface_mask_atlas,
687+
)
688+
689+
registrator.inverse_transform(
690+
fixed_image_path=self.steps[PreprocessorSteps.BET],
691+
moving_image_path=deface_mask_atlas,
692+
transformed_image_path=mask_path,
693+
matrix_path=atlas_bet_M,
694+
log_file_path=defaced_dir_path / "inverse_transform.log",
695+
)
696+
else:
697+
defacer.deface(
698+
input_image_path=self.steps[PreprocessorSteps.BET],
699+
mask_image_path=mask_path,
681700
)
682701

683-
return atlas_mask_path
702+
return mask_path
684703
else:
685704
logger.warning(
686705
"Defacing method not implemented yet. Skipping defacing for this modality."
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import warnings
2+
from .atlas_centric_preprocessor import AtlasCentricPreprocessor
3+
from .native_space_preprocessor import NativeSpacePreprocessor
4+
5+
6+
# Deprecation warning for Preprocessor alias, added to ensure backward compatibility.
7+
class Preprocessor(AtlasCentricPreprocessor):
8+
def __init__(self, *args, **kwargs):
9+
warnings.warn(
10+
"Preprocessor has been renamed to AtlasCentricPreprocessor and is deprecated."
11+
"The alias will be removed in future releases, please migrate to AtlasCentricPreprocessor.",
12+
DeprecationWarning,
13+
stacklevel=2,
14+
)
15+
super().__init__(*args, **kwargs)

0 commit comments

Comments
 (0)