Skip to content

Commit af7b29f

Browse files
author
Donglai Wei
committed
update augmentation strategy
1 parent 7e9ad03 commit af7b29f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

60 files changed

+769
-432
lines changed

connectomics/config/auto_config.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,8 @@
1111
Users can manually override any auto-determined parameters.
1212
"""
1313

14-
import torch
1514
import numpy as np
16-
from typing import Dict, List, Tuple, Optional, Any
15+
from typing import Dict, List, Optional, Any
1716
from dataclasses import dataclass, field
1817
from omegaconf import OmegaConf, DictConfig
1918
import warnings

connectomics/config/gpu_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77

88
import torch
99
import psutil
10-
from typing import Dict, Optional, Tuple
11-
import warnings
10+
from typing import Dict, Tuple
1211

1312

1413
def get_gpu_info() -> Dict[str, any]:

connectomics/config/hydra_config.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -614,20 +614,32 @@ class MonitorConfig:
614614
# Augmentation configurations
615615
@dataclass
616616
class FlipConfig:
617-
"""Random flip augmentation."""
617+
"""Random flipping augmentation."""
618618

619619
enabled: bool = True
620620
prob: float = 0.5
621621
spatial_axis: Optional[List[int]] = None # None = all axes
622622

623623

624+
@dataclass
625+
class AffineConfig:
626+
"""Affine transformation augmentation (rotation, scaling, shearing)."""
627+
628+
enabled: bool = False # Disabled by default (can be combined with Rotate90d)
629+
prob: float = 0.5
630+
rotate_range: Tuple[float, float, float] = (0.2, 0.2, 0.2) # Rotation range in radians (~11°)
631+
scale_range: Tuple[float, float, float] = (0.1, 0.1, 0.1) # Scaling range (±10%)
632+
shear_range: Tuple[float, float, float] = (0.1, 0.1, 0.1) # Shearing range (±10%)
633+
634+
624635
@dataclass
625636
class RotateConfig:
626637
"""Random rotation augmentation."""
627638

628639
enabled: bool = True
629640
prob: float = 0.5
630641
max_angle: float = 90.0
642+
spatial_axes: Tuple[int, int] = (1, 2) # Axes to rotate: (1, 2) = Y-X plane (preserves Z)
631643

632644

633645
@dataclass
@@ -736,10 +748,12 @@ class CopyPasteConfig:
736748
class AugmentationConfig:
737749
"""Complete augmentation configuration."""
738750

751+
preset: str = "some" # "all", "some", or "none" - controls how enabled flags are interpreted
739752
enabled: bool = False
740753

741754
# Standard augmentations
742755
flip: FlipConfig = field(default_factory=FlipConfig)
756+
affine: AffineConfig = field(default_factory=AffineConfig) # Added AffineConfig
743757
rotate: RotateConfig = field(default_factory=RotateConfig)
744758
elastic: ElasticConfig = field(default_factory=ElasticConfig)
745759
intensity: IntensityConfig = field(default_factory=IntensityConfig)
@@ -1022,6 +1036,7 @@ def configure_instance_segmentation(cfg: Config, boundary_thickness: int = 5) ->
10221036
# Augmentation configuration
10231037
"AugmentationConfig",
10241038
"FlipConfig",
1039+
"AffineConfig", # Added AffineConfig
10251040
"RotateConfig",
10261041
"ElasticConfig",
10271042
"IntensityConfig",

connectomics/data/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@
1414
from connectomics.data.process import MultiTaskLabelTransformd, create_label_transform_pipeline
1515
"""
1616

17-
from .dataset.dataset_base import *
18-
from .dataset import *
19-
from .io import *
2017

2118
# Make submodules available
2219
from . import augment

connectomics/data/augment/build.py

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
from __future__ import annotations
88
from typing import Dict
9-
import numpy as np
109
import torch
1110
from monai.transforms import (
1211
Compose,
@@ -41,11 +40,10 @@
4140
RandCutBlurd,
4241
RandMixupd,
4342
RandCopyPasted,
44-
SqueezeLabeld,
4543
NormalizeLabelsd,
4644
SmartNormalizeIntensityd,
4745
)
48-
from ...config.hydra_config import Config, AugmentationConfig, LabelTransformConfig
46+
from ...config.hydra_config import Config, AugmentationConfig
4947

5048

5149
def build_train_transforms(
@@ -522,23 +520,59 @@ def _build_augmentations(aug_cfg: AugmentationConfig, keys: list[str]) -> list:
522520
List of MONAI transforms
523521
"""
524522
transforms = []
525-
523+
524+
# Get preset mode (default to "some" for backward compatibility)
525+
preset = getattr(aug_cfg, "preset", "some")
526+
527+
# Helper function to check if augmentation should be applied
528+
def should_augment(aug_name: str, aug_enabled: bool) -> bool:
529+
"""
530+
Determine if augmentation should be applied based on preset mode.
531+
532+
- "all": enabled = True by default, set to False to disable
533+
- "some": enabled = False by default, set to True to enable
534+
- "none": always False (no augmentations)
535+
"""
536+
if preset == "none":
537+
return False
538+
elif preset == "all":
539+
# All enabled by default, respect False overrides
540+
return aug_enabled
541+
else: # preset == "some"
542+
# None enabled by default, only use True values
543+
return aug_enabled
544+
526545
# Standard geometric augmentations
527-
if aug_cfg.flip.enabled:
546+
if should_augment("flip", aug_cfg.flip.enabled):
528547
transforms.append(
529548
RandFlipd(keys=keys, prob=aug_cfg.flip.prob, spatial_axis=aug_cfg.flip.spatial_axis)
530549
)
531550

532-
if aug_cfg.rotate.enabled:
551+
if should_augment("rotate", aug_cfg.rotate.enabled):
552+
# Use spatial_axes from config if available
553+
spatial_axes = getattr(aug_cfg.rotate, "spatial_axes", (1, 2))
533554
transforms.append(
534555
RandRotate90d(
535556
keys=keys,
536557
prob=aug_cfg.rotate.prob,
537-
spatial_axes=(1, 2), # Rotate in Y-X plane to preserve anisotropic Z
558+
spatial_axes=spatial_axes, # Rotate in specified plane
559+
)
560+
)
561+
562+
if should_augment("affine", aug_cfg.affine.enabled):
563+
transforms.append(
564+
RandAffined(
565+
keys=keys,
566+
prob=aug_cfg.affine.prob,
567+
rotate_range=aug_cfg.affine.rotate_range,
568+
scale_range=aug_cfg.affine.scale_range,
569+
shear_range=aug_cfg.affine.shear_range,
570+
mode="bilinear",
571+
padding_mode="reflection",
538572
)
539573
)
540574

541-
if aug_cfg.elastic.enabled:
575+
if should_augment("elastic", aug_cfg.elastic.enabled):
542576
transforms.append(
543577
Rand3DElasticd(
544578
keys=keys,
@@ -549,7 +583,7 @@ def _build_augmentations(aug_cfg: AugmentationConfig, keys: list[str]) -> list:
549583
)
550584

551585
# Intensity augmentations (only for images)
552-
if aug_cfg.intensity.enabled:
586+
if should_augment("intensity", aug_cfg.intensity.enabled):
553587
if aug_cfg.intensity.gaussian_noise_prob > 0:
554588
transforms.append(
555589
RandGaussianNoised(
@@ -578,7 +612,7 @@ def _build_augmentations(aug_cfg: AugmentationConfig, keys: list[str]) -> list:
578612
)
579613

580614
# EM-specific augmentations
581-
if aug_cfg.misalignment.enabled:
615+
if should_augment("misalignment", aug_cfg.misalignment.enabled):
582616
transforms.append(
583617
RandMisAlignmentd(
584618
keys=keys,
@@ -588,7 +622,7 @@ def _build_augmentations(aug_cfg: AugmentationConfig, keys: list[str]) -> list:
588622
)
589623
)
590624

591-
if aug_cfg.missing_section.enabled:
625+
if should_augment("missing_section", aug_cfg.missing_section.enabled):
592626
transforms.append(
593627
RandMissingSectiond(
594628
keys=keys,
@@ -597,7 +631,7 @@ def _build_augmentations(aug_cfg: AugmentationConfig, keys: list[str]) -> list:
597631
)
598632
)
599633

600-
if aug_cfg.motion_blur.enabled:
634+
if should_augment("motion_blur", aug_cfg.motion_blur.enabled):
601635
transforms.append(
602636
RandMotionBlurd(
603637
keys=["image"],
@@ -607,7 +641,7 @@ def _build_augmentations(aug_cfg: AugmentationConfig, keys: list[str]) -> list:
607641
)
608642
)
609643

610-
if aug_cfg.cut_noise.enabled:
644+
if should_augment("cut_noise", aug_cfg.cut_noise.enabled):
611645
transforms.append(
612646
RandCutNoised(
613647
keys=["image"],
@@ -617,7 +651,7 @@ def _build_augmentations(aug_cfg: AugmentationConfig, keys: list[str]) -> list:
617651
)
618652
)
619653

620-
if aug_cfg.cut_blur.enabled:
654+
if should_augment("cut_blur", aug_cfg.cut_blur.enabled):
621655
transforms.append(
622656
RandCutBlurd(
623657
keys=["image"],
@@ -628,7 +662,7 @@ def _build_augmentations(aug_cfg: AugmentationConfig, keys: list[str]) -> list:
628662
)
629663
)
630664

631-
if aug_cfg.missing_parts.enabled:
665+
if should_augment("missing_parts", aug_cfg.missing_parts.enabled):
632666
transforms.append(
633667
RandMissingPartsd(
634668
keys=keys,
@@ -638,14 +672,14 @@ def _build_augmentations(aug_cfg: AugmentationConfig, keys: list[str]) -> list:
638672
)
639673

640674
# Advanced augmentations
641-
if aug_cfg.mixup.enabled:
675+
if should_augment("mixup", aug_cfg.mixup.enabled):
642676
transforms.append(
643677
RandMixupd(
644678
keys=["image"], prob=aug_cfg.mixup.prob, alpha_range=aug_cfg.mixup.alpha_range
645679
)
646680
)
647681

648-
if aug_cfg.copy_paste.enabled:
682+
if should_augment("copy_paste", aug_cfg.copy_paste.enabled):
649683
transforms.append(
650684
RandCopyPasted(
651685
keys=["image"],

connectomics/data/augment/monai_transforms.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import cv2
1414
from monai.config import KeysCollection
1515
from monai.transforms import MapTransform, RandomizableTransform
16-
from monai.utils import ensure_tuple_rep
1716

1817

1918
class RandMisAlignmentd(RandomizableTransform, MapTransform):
@@ -807,7 +806,6 @@ def _find_best_paste(
807806
label_flipped: torch.Tensor,
808807
) -> Tuple[torch.Tensor, torch.Tensor]:
809808
"""Find best rotation and position with minimal overlap."""
810-
import torchvision.transforms.functional as tf
811809
from scipy.ndimage import binary_dilation
812810

813811
labels = torch.stack([label_orig, label_flipped])

connectomics/data/dataset/dataset_base.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,17 @@
77
"""
88

99
from __future__ import annotations
10-
from typing import Dict, List, Any, Optional, Union, Callable, Sequence, Tuple
11-
from abc import ABC, abstractmethod
12-
import os
10+
from typing import Dict, Any, Optional, Sequence, Tuple
1311
import numpy as np
1412

1513
import torch
1614
import torch.utils.data
17-
import pytorch_lightning as pl
18-
from torch.utils.data import DataLoader
1915

2016
# MONAI imports
2117
from monai.data import Dataset, CacheDataset, PersistentDataset
2218
from monai.transforms import Compose
2319
from monai.utils import ensure_tuple_rep
2420

25-
from ..io import read_volume
2621

2722

2823
class MonaiConnectomicsDataset(Dataset):

connectomics/data/dataset/dataset_filename.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,16 @@
77
"""
88

99
from __future__ import annotations
10-
from typing import Dict, List, Any, Optional, Union, Sequence, Tuple
10+
from typing import Dict, Any, Optional, Tuple
1111
import json
1212
import random
1313
from pathlib import Path
1414
import warnings
1515

16-
import numpy as np
1716
import torch
1817
from monai.data import Dataset
19-
from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd
20-
from monai.utils import ensure_tuple_rep
18+
from monai.transforms import Compose
2119

22-
from .dataset_base import MonaiConnectomicsDataset
2320

2421

2522
class MonaiFilenameDataset(Dataset):

connectomics/data/dataset/dataset_tile.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,15 @@
66
"""
77

88
from __future__ import annotations
9-
from typing import Dict, List, Any, Optional, Union, Sequence, Tuple
10-
import numpy as np
9+
from typing import Dict, List, Any, Optional, Tuple
1110
import json
12-
import random
1311

14-
import torch
15-
from monai.data import Dataset, CacheDataset
16-
from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd
12+
from monai.data import CacheDataset
13+
from monai.transforms import Compose, EnsureChannelFirstd
1714
from monai.utils import ensure_tuple_rep
1815

1916
from .dataset_base import MonaiConnectomicsDataset
20-
from ..io import create_tile_metadata, reconstruct_volume_from_tiles, TileLoaderd
17+
from ..io import TileLoaderd
2118

2219

2320
class MonaiTileDataset(MonaiConnectomicsDataset):

connectomics/data/dataset/dataset_volume.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,13 @@
66
"""
77

88
from __future__ import annotations
9-
from typing import Dict, List, Any, Optional, Union, Callable, Sequence, Tuple
10-
import numpy as np
11-
import random
12-
import warnings
13-
14-
import torch
15-
from monai.data import Dataset, CacheDataset
16-
from monai.transforms import Compose, RandSpatialCropd, CenterSpatialCropd, LoadImaged, EnsureChannelFirstd
9+
from typing import List, Optional, Tuple
10+
11+
from monai.data import CacheDataset
12+
from monai.transforms import Compose, RandSpatialCropd, CenterSpatialCropd
1713
from monai.utils import ensure_tuple_rep
1814

1915
from .dataset_base import MonaiConnectomicsDataset
20-
from monai.config import KeysCollection
21-
from monai.transforms import MapTransform
2216
from ..io.monai_transforms import LoadVolumed
2317

2418

@@ -210,8 +204,8 @@ def __init__(
210204
sample_size = kwargs.get('sample_size', (32, 256, 256))
211205
mode = kwargs.get('mode', 'train')
212206
do_2d = kwargs.get('do_2d', False)
213-
data_mean = kwargs.get('data_mean', 0.5)
214-
data_std = kwargs.get('data_std', 0.5)
207+
kwargs.get('data_mean', 0.5)
208+
kwargs.get('data_std', 0.5)
215209
transpose_axes = kwargs.get('transpose_axes', None)
216210

217211
# Create data dictionaries

0 commit comments

Comments
 (0)