66
77from __future__ import annotations
88from typing import Dict
9- import numpy as np
109import torch
1110from monai .transforms import (
1211 Compose ,
4140 RandCutBlurd ,
4241 RandMixupd ,
4342 RandCopyPasted ,
44- SqueezeLabeld ,
4543 NormalizeLabelsd ,
4644 SmartNormalizeIntensityd ,
4745)
48- from ...config .hydra_config import Config , AugmentationConfig , LabelTransformConfig
46+ from ...config .hydra_config import Config , AugmentationConfig
4947
5048
5149def build_train_transforms (
@@ -522,23 +520,59 @@ def _build_augmentations(aug_cfg: AugmentationConfig, keys: list[str]) -> list:
522520 List of MONAI transforms
523521 """
524522 transforms = []
525-
523+
524+ # Get preset mode (default to "some" for backward compatibility)
525+ preset = getattr (aug_cfg , "preset" , "some" )
526+
527+ # Helper function to check if augmentation should be applied
528+ def should_augment (aug_name : str , aug_enabled : bool ) -> bool :
529+ """
530+ Determine if augmentation should be applied based on preset mode.
531+
532+ - "all": enabled = True by default, set to False to disable
533+ - "some": enabled = False by default, set to True to enable
534+ - "none": always False (no augmentations)
535+ """
536+ if preset == "none" :
537+ return False
538+ elif preset == "all" :
539+ # All enabled by default, respect False overrides
540+ return aug_enabled
541+ else : # preset == "some"
542+ # None enabled by default, only use True values
543+ return aug_enabled
544+
526545 # Standard geometric augmentations
527- if aug_cfg .flip .enabled :
546+ if should_augment ( "flip" , aug_cfg .flip .enabled ) :
528547 transforms .append (
529548 RandFlipd (keys = keys , prob = aug_cfg .flip .prob , spatial_axis = aug_cfg .flip .spatial_axis )
530549 )
531550
532- if aug_cfg .rotate .enabled :
551+ if should_augment ("rotate" , aug_cfg .rotate .enabled ):
552+ # Use spatial_axes from config if available
553+ spatial_axes = getattr (aug_cfg .rotate , "spatial_axes" , (1 , 2 ))
533554 transforms .append (
534555 RandRotate90d (
535556 keys = keys ,
536557 prob = aug_cfg .rotate .prob ,
537- spatial_axes = (1 , 2 ), # Rotate in Y-X plane to preserve anisotropic Z
558+ spatial_axes = spatial_axes , # Rotate in specified plane
559+ )
560+ )
561+
562+ if should_augment ("affine" , aug_cfg .affine .enabled ):
563+ transforms .append (
564+ RandAffined (
565+ keys = keys ,
566+ prob = aug_cfg .affine .prob ,
567+ rotate_range = aug_cfg .affine .rotate_range ,
568+ scale_range = aug_cfg .affine .scale_range ,
569+ shear_range = aug_cfg .affine .shear_range ,
570+ mode = "bilinear" ,
571+ padding_mode = "reflection" ,
538572 )
539573 )
540574
541- if aug_cfg .elastic .enabled :
575+ if should_augment ( "elastic" , aug_cfg .elastic .enabled ) :
542576 transforms .append (
543577 Rand3DElasticd (
544578 keys = keys ,
@@ -549,7 +583,7 @@ def _build_augmentations(aug_cfg: AugmentationConfig, keys: list[str]) -> list:
549583 )
550584
551585 # Intensity augmentations (only for images)
552- if aug_cfg .intensity .enabled :
586+ if should_augment ( "intensity" , aug_cfg .intensity .enabled ) :
553587 if aug_cfg .intensity .gaussian_noise_prob > 0 :
554588 transforms .append (
555589 RandGaussianNoised (
@@ -578,7 +612,7 @@ def _build_augmentations(aug_cfg: AugmentationConfig, keys: list[str]) -> list:
578612 )
579613
580614 # EM-specific augmentations
581- if aug_cfg .misalignment .enabled :
615+ if should_augment ( "misalignment" , aug_cfg .misalignment .enabled ) :
582616 transforms .append (
583617 RandMisAlignmentd (
584618 keys = keys ,
@@ -588,7 +622,7 @@ def _build_augmentations(aug_cfg: AugmentationConfig, keys: list[str]) -> list:
588622 )
589623 )
590624
591- if aug_cfg .missing_section .enabled :
625+ if should_augment ( "missing_section" , aug_cfg .missing_section .enabled ) :
592626 transforms .append (
593627 RandMissingSectiond (
594628 keys = keys ,
@@ -597,7 +631,7 @@ def _build_augmentations(aug_cfg: AugmentationConfig, keys: list[str]) -> list:
597631 )
598632 )
599633
600- if aug_cfg .motion_blur .enabled :
634+ if should_augment ( "motion_blur" , aug_cfg .motion_blur .enabled ) :
601635 transforms .append (
602636 RandMotionBlurd (
603637 keys = ["image" ],
@@ -607,7 +641,7 @@ def _build_augmentations(aug_cfg: AugmentationConfig, keys: list[str]) -> list:
607641 )
608642 )
609643
610- if aug_cfg .cut_noise .enabled :
644+ if should_augment ( "cut_noise" , aug_cfg .cut_noise .enabled ) :
611645 transforms .append (
612646 RandCutNoised (
613647 keys = ["image" ],
@@ -617,7 +651,7 @@ def _build_augmentations(aug_cfg: AugmentationConfig, keys: list[str]) -> list:
617651 )
618652 )
619653
620- if aug_cfg .cut_blur .enabled :
654+ if should_augment ( "cut_blur" , aug_cfg .cut_blur .enabled ) :
621655 transforms .append (
622656 RandCutBlurd (
623657 keys = ["image" ],
@@ -628,7 +662,7 @@ def _build_augmentations(aug_cfg: AugmentationConfig, keys: list[str]) -> list:
628662 )
629663 )
630664
631- if aug_cfg .missing_parts .enabled :
665+ if should_augment ( "missing_parts" , aug_cfg .missing_parts .enabled ) :
632666 transforms .append (
633667 RandMissingPartsd (
634668 keys = keys ,
@@ -638,14 +672,14 @@ def _build_augmentations(aug_cfg: AugmentationConfig, keys: list[str]) -> list:
638672 )
639673
640674 # Advanced augmentations
641- if aug_cfg .mixup .enabled :
675+ if should_augment ( "mixup" , aug_cfg .mixup .enabled ) :
642676 transforms .append (
643677 RandMixupd (
644678 keys = ["image" ], prob = aug_cfg .mixup .prob , alpha_range = aug_cfg .mixup .alpha_range
645679 )
646680 )
647681
648- if aug_cfg .copy_paste .enabled :
682+ if should_augment ( "copy_paste" , aug_cfg .copy_paste .enabled ) :
649683 transforms .append (
650684 RandCopyPasted (
651685 keys = ["image" ],
0 commit comments