Skip to content

Commit 5a719b4

Browse files
authored
Merge pull request #162 from BoyuShen2004/master
feat: Add 2D data processing support with do_2d flag
2 parents 1c687f4 + 9c9a6b7 commit 5a719b4

File tree

6 files changed

+70
-24
lines changed

6 files changed

+70
-24
lines changed

connectomics/config/hydra_config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,10 +329,14 @@ class DataConfig:
329329
- Multi-channel label transformations
330330
- Train/validation splitting options
331331
- Caching and performance optimization
332+
- 2D data support with do_2d parameter
332333
"""
333334

334335
# Dataset type
335336
dataset_type: Optional[str] = None # Type of dataset: None (volume), 'filename', 'tile', etc.
337+
338+
# 2D data support
339+
do_2d: bool = False # Enable 2D data processing (extract 2D slices from 3D volumes)
336340

337341
# Base path (prepended to train_image, train_label, etc. if set)
338342
train_path: str = "" # Base path for training data (e.g., "/path/to/dataset/")
@@ -765,6 +769,9 @@ class InferenceDataConfig:
765769
default_factory=list
766770
) # Axis permutation for test data (e.g., [2,1,0] for xyz->zyx)
767771
output_path: str = "results/"
772+
773+
# 2D data support
774+
do_2d: bool = False # Enable 2D data processing for inference
768775

769776

770777
@dataclass

connectomics/data/dataset/dataset_base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,11 @@ def __init__(
7070
super().__init__(data=data_dicts, transform=transforms)
7171

7272
# Store connectomics-specific parameters
73-
self.sample_size = ensure_tuple_rep(sample_size, 3)
73+
# For 2D data, use 2D dimensions; otherwise use 3D
74+
if do_2d:
75+
self.sample_size = ensure_tuple_rep(sample_size, 2)
76+
else:
77+
self.sample_size = ensure_tuple_rep(sample_size, 3)
7478
self.mode = mode
7579
self.iter_num = iter_num
7680
self.valid_ratio = valid_ratio

connectomics/lightning/lit_model.py

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,15 @@ def _setup_sliding_window_inferer(self):
156156
if not hasattr(self.cfg, 'inference'):
157157
return
158158

159+
# For 2D models with do_2d=True, disable sliding window inference
160+
if getattr(self.cfg.data, 'do_2d', False):
161+
warnings.warn(
162+
"Sliding-window inference disabled for 2D models with do_2d=True. "
163+
"Using direct inference instead.",
164+
UserWarning,
165+
)
166+
return
167+
159168
roi_size = self._resolve_inferer_roi_size()
160169
if roi_size is None:
161170
warnings.warn(
@@ -200,12 +209,20 @@ def _resolve_inferer_roi_size(self) -> Optional[Tuple[int, ...]]:
200209
if hasattr(self.cfg, 'model') and hasattr(self.cfg.model, 'output_size'):
201210
output_size = getattr(self.cfg.model, 'output_size', None)
202211
if output_size:
203-
return tuple(int(v) for v in output_size)
212+
roi_size = tuple(int(v) for v in output_size)
213+
# For 2D models with do_2d=True, convert to 3D ROI size
214+
if getattr(self.cfg.data, 'do_2d', False) and len(roi_size) == 2:
215+
roi_size = (1,) + roi_size # Add depth dimension
216+
return roi_size
204217

205218
if hasattr(self.cfg, 'data') and hasattr(self.cfg.data, 'patch_size'):
206219
patch_size = getattr(self.cfg.data, 'patch_size', None)
207220
if patch_size:
208-
return tuple(int(v) for v in patch_size)
221+
roi_size = tuple(int(v) for v in patch_size)
222+
# For 2D models with do_2d=True, convert to 3D ROI size
223+
if getattr(self.cfg.data, 'do_2d', False) and len(roi_size) == 2:
224+
roi_size = (1,) + roi_size # Add depth dimension
225+
return roi_size
209226

210227
return None
211228

@@ -367,6 +384,10 @@ def _predict_with_tta(self, images: torch.Tensor, mask: Optional[torch.Tensor] =
367384
f"Expected shapes: (D, H, W), (B, D, H, W), or (B, C, D, H, W)"
368385
)
369386

387+
# For 2D models with do_2d=True, squeeze the depth dimension if present
388+
if getattr(self.cfg.data, 'do_2d', False) and images.size(2) == 1: # [B, C, 1, H, W] -> [B, C, H, W]
389+
images = images.squeeze(2)
390+
370391
# Get TTA configuration (default to no augmentation if not configured)
371392
if hasattr(self.cfg, 'inference') and hasattr(self.cfg.inference, 'test_time_augmentation'):
372393
tta_flip_axes_config = getattr(self.cfg.inference.test_time_augmentation, 'flip_axes', None)
@@ -385,23 +406,21 @@ def _predict_with_tta(self, images: torch.Tensor, mask: Optional[torch.Tensor] =
385406
ensemble_result = self._apply_tta_preprocessing(pred)
386407
else:
387408
if tta_flip_axes_config == 'all' or tta_flip_axes_config == []:
388-
# "all" or []: All 8 flips (all combinations of Z, Y, X)
389-
# IMPORTANT: MONAI Flip spatial_axis behavior for (B, C, D, H, W) tensors:
390-
# spatial_axis=[0] flips C (channel) - WRONG for TTA!
391-
# spatial_axis=[1] flips D (depth/Z) - CORRECT
392-
# spatial_axis=[2] flips H (height/Y) - CORRECT
393-
# spatial_axis=[3] flips W (width/X) - CORRECT
394-
# Must use [1, 2, 3] for [D, H, W] flips, NOT [0, 1, 2]!
395-
tta_flip_axes = [
396-
[], # No flip
397-
[1], # Flip Z (depth)
398-
[2], # Flip Y (height)
399-
[3], # Flip X (width)
400-
[1, 2], # Flip Z+Y
401-
[1, 3], # Flip Z+X
402-
[2, 3], # Flip Y+X
403-
[1, 2, 3], # Flip Z+Y+X
404-
]
409+
# "all" or []: All flips (all combinations of spatial axes)
410+
# Determine spatial axes based on data dimensions
411+
if images.dim() == 5: # 3D data: [B, C, D, H, W]
412+
spatial_axes = [1, 2, 3] # [D, H, W]
413+
elif images.dim() == 4: # 2D data: [B, C, H, W]
414+
spatial_axes = [1, 2] # [H, W]
415+
else:
416+
raise ValueError(f"Unsupported data dimensions: {images.dim()}")
417+
418+
# Generate all combinations of spatial axes
419+
tta_flip_axes = [[]] # No flip baseline
420+
for r in range(1, len(spatial_axes) + 1):
421+
from itertools import combinations
422+
for combo in combinations(spatial_axes, r):
423+
tta_flip_axes.append(list(combo))
405424
elif isinstance(tta_flip_axes_config, (list, tuple)):
406425
# Custom list: Add no-flip baseline + user-specified flips
407426
tta_flip_axes = [[]] + list(tta_flip_axes_config)

connectomics/models/arch/monai_models.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,18 @@ def __init__(self, model: nn.Module):
3636

3737
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Run the wrapped model, handling singleton-depth 5D input transparently.

    A 5D input whose dim 2 has size 1 ([B, C, 1, H, W]) is squeezed to
    [B, C, H, W] before ``self.model`` is called. If the model then returns
    a 4D tensor, the singleton axis is re-inserted at dim 2 so the output
    rank matches the input rank. Any other input is passed through as-is.

    Args:
        x: Input tensor handed to the wrapped model.

    Returns:
        The output of ``self.model``; [B, C, 1, H, W] when a singleton
        depth axis was removed from the input and the model produced a
        4D result, otherwise whatever the model returned.
    """
    # Decide this BEFORE reassigning x: once squeezed, x is 4D, so a later
    # ``x.dim() == 5`` check could never succeed and the depth axis would
    # never be restored.
    squeezed_depth = x.dim() == 5 and x.size(2) == 1
    if squeezed_depth:
        x = x.squeeze(2)  # [B, C, 1, H, W] -> [B, C, H, W]

    # Forward through model
    output = self.model(x)

    # Restore the depth axis only when this call removed it.
    if squeezed_depth and output.dim() == 4:
        output = output.unsqueeze(2)  # [B, C, H, W] -> [B, C, 1, H, W]

    return output
4051

4152

4253
def _check_monai_available():

scripts/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,7 @@ def setup(self, stage=None):
821821
cache_rate=cfg.data.cache_rate if use_cache else 0.0,
822822
iter_num=iter_num_for_dataset,
823823
sample_size=tuple(cfg.data.patch_size),
824+
do_2d=cfg.data.do_2d,
824825
)
825826
# Setup datasets based on mode
826827
if mode == "train":

tutorials/monai2d_worm.yaml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ model:
5858

5959
# Data - Using automatic 80/20 train/val split (DeepEM-style)
6060
data:
61+
# 2D data support
62+
do_2d: true # Enable 2D data processing (extract 2D slices from 3D volumes)
63+
6164
# Volume configuration
6265
train_image: /projects/weilab/shenb/PyTC/datasets/Dataset001_worm_image96/imagesTr/*.tif
6366
train_label: /projects/weilab/shenb/PyTC/datasets/Dataset001_worm_image96/labelsTr/*.tif
@@ -152,8 +155,9 @@ monitor:
152155
# Inference - MONAI SlidingWindowInferer (skipped when data.do_2d is true; direct inference is used instead)
153156
inference:
154157
data:
155-
test_image: /projects/weilab/shenb/PyTC/datasets/Dataset001_worm_image96/imagesTr/Image96_00002_0000.tif
156-
test_label: /projects/weilab/shenb/PyTC/datasets/Dataset001_worm_image96/labelsTr/Image96_00002.tif
158+
do_2d: true # Enable 2D data processing for inference
159+
test_image: /projects/weilab/shenb/PyTC/datasets/Dataset001_worm_image96/imagesTs/*.tif
160+
test_label: /projects/weilab/shenb/PyTC/datasets/Dataset001_worm_image96/imagesTs/*.tif
157161
test_resolution: [5, 5]
158162
output_path: outputs/monai2d_worm/results/
159163

@@ -188,7 +192,7 @@ inference:
188192
# Evaluation
189193
evaluation:
190194
enabled: true # Use eval mode for BatchNorm
191-
metrics: [jaccard] # Metrics to compute
195+
metrics: [adapted_rand] # Metrics to compute (adapted_rand for instance segmentation)
192196

193197
# NOTE: batch_size=1 for inference
194198
# During training: batch_size controls how many random patches to load

0 commit comments

Comments
 (0)