
Commit 7e9ad03

Author: Donglai Wei

fix multi-task loss and label mapping

1 parent 1c96c78

File tree

11 files changed: +274 additions, −91 deletions


connectomics/config/hydra_config.py

Lines changed: 2 additions & 0 deletions

@@ -313,6 +313,8 @@ class ImageTransformConfig:
     clip_percentile_high: float = (
         1.0  # Upper percentile for clipping (1.0 = no clip, 0.95 = 95th percentile)
     )
+    pad_size: Optional[List[int]] = None  # Reflection padding for context [D, H, W] or [H, W]
+    pad_mode: str = "reflect"  # Padding mode: 'reflect', 'replicate', 'constant'


 @dataclass
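
For orientation, a minimal sketch of setting the new fields programmatically; the module path and class name come from the diff header, the values are illustrative:

    from connectomics.config.hydra_config import ImageTransformConfig

    # Pad each spatial axis with 16 voxels of reflected context before cropping;
    # a 2-element list like [16, 16] would target 2D [H, W] data instead.
    tf_cfg = ImageTransformConfig(pad_size=[16, 16, 16], pad_mode="reflect")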

connectomics/data/augment/build.py

Lines changed: 53 additions & 25 deletions

@@ -25,6 +25,8 @@
     CenterSpatialCropd,
     SpatialPadd,
     Resized,
+    LoadImaged,  # For filename-based datasets (PNG, JPG, etc.)
+    EnsureChannelFirstd,  # Ensure channel-first format for 2D/3D images
 )

 # Import custom loader for HDF5/TIFF volumes

@@ -71,11 +73,20 @@ def build_train_transforms(

     # Load images first (unless using pre-cached dataset)
     if not skip_loading:
-        # Get transpose axes for training data
-        train_transpose = cfg.data.train_transpose if cfg.data.train_transpose else []
-        transforms.append(
-            LoadVolumed(keys=keys, transpose_axes=train_transpose if train_transpose else None)
-        )
+        # Use appropriate loader based on dataset type
+        dataset_type = getattr(cfg.data, "dataset_type", "volume")  # Default to volume for backward compatibility
+
+        if dataset_type == "filename":
+            # For filename-based datasets (PNG, JPG, etc.), use MONAI's LoadImaged
+            transforms.append(LoadImaged(keys=keys, image_only=False))
+            # Ensure channel-first format [C, H, W] or [C, D, H, W]
+            transforms.append(EnsureChannelFirstd(keys=keys))
+        else:
+            # For volume-based datasets (HDF5, TIFF volumes), use custom LoadVolumed
+            train_transpose = cfg.data.train_transpose if cfg.data.train_transpose else []
+            transforms.append(
+                LoadVolumed(keys=keys, transpose_axes=train_transpose if train_transpose else None)
+            )

     # Apply volumetric split if enabled
     if cfg.data.split_enabled:

@@ -212,12 +223,20 @@ def build_val_transforms(cfg: Config, keys: list[str] = None) -> Compose:

     transforms = []

-    # Load images first
-    # Get transpose axes for validation data
-    val_transpose = cfg.data.val_transpose if cfg.data.val_transpose else []
-    transforms.append(
-        LoadVolumed(keys=keys, transpose_axes=val_transpose if val_transpose else None)
-    )
+    # Load images first - use appropriate loader based on dataset type
+    dataset_type = getattr(cfg.data, "dataset_type", "volume")  # Default to volume for backward compatibility
+
+    if dataset_type == "filename":
+        # For filename-based datasets (PNG, JPG, etc.), use MONAI's LoadImaged
+        transforms.append(LoadImaged(keys=keys, image_only=False))
+        # Ensure channel-first format [C, H, W] or [C, D, H, W]
+        transforms.append(EnsureChannelFirstd(keys=keys))
+    else:
+        # For volume-based datasets (HDF5, TIFF volumes), use custom LoadVolumed
+        val_transpose = cfg.data.val_transpose if cfg.data.val_transpose else []
+        transforms.append(
+            LoadVolumed(keys=keys, transpose_axes=val_transpose if val_transpose else None)
+        )

     # Apply volumetric split if enabled
     if cfg.data.split_enabled:

@@ -342,20 +361,29 @@ def build_test_transforms(cfg: Config, keys: list[str] = None) -> Compose:

     transforms = []

-    # Load images first
-    # Get transpose axes for test data (check both data.test_transpose and inference.data.test_transpose)
-    test_transpose = []
-    if cfg.data.test_transpose:
-        test_transpose = cfg.data.test_transpose
-    if (
-        hasattr(cfg, "inference")
-        and hasattr(cfg.inference, "data")
-        and hasattr(cfg.inference.data, "test_transpose")
-        and cfg.inference.data.test_transpose
-    ):
-        test_transpose = cfg.inference.data.test_transpose  # inference takes precedence
-    transforms.append(
-        LoadVolumed(keys=keys, transpose_axes=test_transpose if test_transpose else None)
+    # Load images first - use appropriate loader based on dataset type
+    dataset_type = getattr(cfg.data, "dataset_type", "volume")  # Default to volume for backward compatibility
+
+    if dataset_type == "filename":
+        # For filename-based datasets (PNG, JPG, etc.), use MONAI's LoadImaged
+        transforms.append(LoadImaged(keys=keys, image_only=False))
+        # Ensure channel-first format [C, H, W] or [C, D, H, W]
+        transforms.append(EnsureChannelFirstd(keys=keys))
+    else:
+        # For volume-based datasets (HDF5, TIFF volumes), use custom LoadVolumed
+        # Get transpose axes for test data (check both data.test_transpose and inference.data.test_transpose)
+        test_transpose = []
+        if cfg.data.test_transpose:
+            test_transpose = cfg.data.test_transpose
+        if (
+            hasattr(cfg, "inference")
+            and hasattr(cfg.inference, "data")
+            and hasattr(cfg.inference.data, "test_transpose")
+            and cfg.inference.data.test_transpose
+        ):
+            test_transpose = cfg.inference.data.test_transpose  # inference takes precedence
+        transforms.append(
+            LoadVolumed(keys=keys, transpose_axes=test_transpose if test_transpose else None)
        )

     # Apply volumetric split if enabled (though typically not used for test)
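
The filename branch in isolation, as a runnable sketch (LoadImaged and EnsureChannelFirstd are standard MONAI dictionary transforms; the sample path is hypothetical):

    from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged

    pipeline = Compose([
        LoadImaged(keys=["image"], image_only=False),  # reader chosen from the file extension
        EnsureChannelFirstd(keys=["image"]),           # [H, W] -> [1, H, W], [H, W, 3] -> [3, H, W]
    ])
    sample = pipeline({"image": "slice_0001.png"})  # hypothetical filename
    print(sample["image"].shape)  # channel-first, e.g. (1, H, W)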

connectomics/data/dataset/dataset_volume_cached.py

Lines changed: 53 additions & 0 deletions

@@ -79,6 +79,8 @@ def __init__(
         iter_num: int = 500,
         transforms: Optional[Compose] = None,
         mode: str = "train",
+        pad_size: Optional[Tuple[int, ...]] = None,
+        pad_mode: str = "reflect",
     ):
         self.image_paths = image_paths
         self.label_paths = label_paths if label_paths else [None] * len(image_paths)

@@ -97,6 +99,8 @@ def __init__(
         self.iter_num = iter_num if iter_num > 0 else len(image_paths)
         self.transforms = transforms
         self.mode = mode
+        self.pad_size = pad_size
+        self.pad_mode = pad_mode

         # Load all volumes into memory
         print(f"  Loading {len(image_paths)} volumes into memory...")

@@ -116,6 +120,11 @@ def __init__(
                 img = img[None, ...]  # Add channel for 2D
             elif img.ndim == 3:
                 img = img[None, ...]  # Add channel for 3D
+
+            # Apply padding if specified
+            if self.pad_size is not None:
+                img = self._apply_padding(img)
+
             self.cached_images.append(img)

             # Load label if available

@@ -126,6 +135,11 @@ def __init__(
                 lbl = lbl[None, ...]  # Add channel for 2D
             elif lbl.ndim == 3:
                 lbl = lbl[None, ...]  # Add channel for 3D
+
+            # Apply padding if specified (same padding as image)
+            if self.pad_size is not None:
+                lbl = self._apply_padding(lbl, mode='constant', constant_values=0)  # Use constant 0 for labels
+
             self.cached_labels.append(lbl)
         else:
             self.cached_labels.append(None)

@@ -135,6 +149,11 @@ def __init__(
             mask = read_volume(mask_path)
             if mask.ndim == 3:
                 mask = mask[None, ...]
+
+            # Apply padding if specified (same padding as label)
+            if self.pad_size is not None:
+                mask = self._apply_padding(mask, mode='constant', constant_values=0)
+
             self.cached_masks.append(mask)
         else:
             self.cached_masks.append(None)

@@ -148,6 +167,40 @@ def __init__(
         ndim = len(self.patch_size)
         self.volume_sizes = [img.shape[-ndim:] for img in self.cached_images]  # (Z, Y, X) or (Y, X)

+    def _apply_padding(
+        self, volume: np.ndarray, mode: Optional[str] = None, constant_values: float = 0
+    ) -> np.ndarray:
+        """
+        Apply padding to a volume using np.pad.
+
+        Args:
+            volume: Input volume with channel dimension (C, D, H, W) or (C, H, W)
+            mode: Padding mode ('reflect', 'constant', etc.). If None, uses self.pad_mode
+            constant_values: Value for constant padding
+
+        Returns:
+            Padded volume
+        """
+        if self.pad_size is None:
+            return volume
+
+        mode = mode if mode is not None else self.pad_mode
+
+        # Build padding tuple for np.pad: ((before, after), ...)
+        # For channel dimension: no padding (0, 0)
+        # For spatial dimensions: pad according to pad_size
+        pad_width = [(0, 0)]  # No padding on channel dimension
+        for p in self.pad_size:
+            pad_width.append((p, p))
+
+        # Apply padding using np.pad
+        if mode == 'constant':
+            padded = np.pad(volume, pad_width, mode=mode, constant_values=constant_values)
+        else:
+            padded = np.pad(volume, pad_width, mode=mode)
+
+        return padded
+
     def __len__(self) -> int:
         return self.iter_num
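
The padding arithmetic of _apply_padding, shown standalone with np.pad (a sketch, not the class method itself):

    import numpy as np

    vol = np.zeros((1, 2, 4, 4), dtype=np.float32)     # (C, D, H, W)
    pad_size = (1, 2, 2)                               # half-width per spatial axis
    pad_width = [(0, 0)] + [(p, p) for p in pad_size]  # channel axis is never padded

    print(np.pad(vol, pad_width, mode="reflect").shape)   # (1, 4, 8, 8): each axis grows by 2*p
    print(np.pad(vol, pad_width, mode="constant").shape)  # same shape, zero borders (used for labels/masks)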

connectomics/data/process/build.py

Lines changed: 10 additions & 3 deletions

@@ -123,9 +123,16 @@ def create_label_transform_pipeline(cfg: Any = None, **kwargs: Any) -> Compose:
     cfg = _coerce_config(cfg, kwargs)

     # Keys configuration
-    keys_attr = getattr(cfg, 'keys', None)
-    if keys_attr is None:
-        keys_option = [getattr(cfg, 'input_key', 'label')]
+    # Note: Must check if 'keys' exists in config to avoid getting dict.keys() method
+    if hasattr(cfg, '__dict__') and 'keys' in cfg.__dict__:
+        keys_attr = cfg.keys
+    elif hasattr(cfg, '__contains__') and 'keys' in cfg:
+        keys_attr = cfg['keys'] if isinstance(cfg, dict) else getattr(cfg, 'keys')
+    else:
+        keys_attr = None
+
+    if keys_attr is None or callable(keys_attr):  # Protect against dict.keys() method
+        keys_option = [getattr(cfg, 'input_key', None) or cfg.get('input_key', 'label') if isinstance(cfg, dict) else 'label']
     elif isinstance(keys_attr, str):
         keys_option = [keys_attr]
     else:
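
The pitfall the new guard addresses, in two lines (standalone illustration): on a plain dict, attribute lookup finds the built-in method rather than a config value.

    cfg = {"input_key": "label"}
    print(getattr(cfg, "keys", None))  # <built-in method keys of dict object ...> — truthy, not None
    print(callable(cfg.keys))          # True, hence the callable(keys_attr) check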

connectomics/data/process/distance.py

Lines changed: 2 additions & 9 deletions

@@ -256,20 +256,13 @@ def skeleton_aware_distance_transform(
     label = np.pad(label, pad_size, mode="constant", constant_values=0)

     label_shape = label.shape
-    all_bg_sample = False

     skeleton = np.zeros(label_shape, dtype=np.uint8)
     distance = np.zeros(label_shape, dtype=np.float32)

     indices = np.unique(label)
-    if indices[0] == 0:
-        if len(indices) > 1:  # exclude background
-            indices = indices[1:]
-        else:  # all-background sample
-            all_bg_sample = True
-
-    if not all_bg_sample:
-        for idx in indices:
+    if len(indices) > 1:
+        for idx in indices[indices > 0]:
             temp2 = remove_small_holes(label == idx, 16, connectivity=1)
             binary = temp2.copy()

connectomics/data/process/monai_transforms.py

Lines changed: 42 additions & 9 deletions

@@ -718,23 +718,56 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         label = d[key]
         label_np, had_batch_dim = self._prepare_label(label)

-        # Remove channel dimension if it's 1 (target functions expect [D, H, W] not [1, D, H, W])
-        if label_np.ndim == 4 and label_np.shape[0] == 1:
-            label_np = label_np[0]
+        # Determine if input is 2D or 3D based on original dimensions
+        # After EnsureChannelFirstd: 2D images are [1, H, W], 3D volumes are [1, D, H, W]
+        is_2d_input = label_np.ndim == 3 and label_np.shape[0] == 1
+        is_3d_input = label_np.ndim == 4 and label_np.shape[0] == 1
+
+        # Remove channel dimension (target functions don't expect it)
+        if is_3d_input:
+            label_np = label_np[0]  # [1, D, H, W] -> [D, H, W]
+        elif is_2d_input:
+            # For 2D, keep as [1, H, W] since some functions (boundary, edt) expect 3D input
+            # even in 2D mode (they treat first dim as Z=1)
+            pass  # Keep [1, H, W]

         outputs: List[np.ndarray] = []
         for spec in self.task_specs:
-            result = spec["fn"](label_np, **spec["kwargs"])
+            try:
+                result = spec["fn"](label_np, **spec["kwargs"])
+            except Exception as e:
+                raise RuntimeError(
+                    f"Task '{spec['name']}' failed with error: {e}\n"
+                    f"Label shape: {label_np.shape}, dtype: {label_np.dtype}\n"
+                    f"Task kwargs: {spec['kwargs']}"
+                ) from e
             if result is None:
-                raise RuntimeError(f"Task '{spec['name']}' returned None.")
+                raise RuntimeError(
+                    f"Task '{spec['name']}' returned None.\n"
+                    f"Label shape: {label_np.shape}, dtype: {label_np.dtype}\n"
+                    f"Task kwargs: {spec['kwargs']}"
+                )
             result_arr = np.asarray(
                 result, dtype=np.float32
             )  # Convert to float32 (handles bool->float)

-            # Ensure each output has a channel dimension [C, D, H, W]
-            # If output is [D, H, W], expand to [1, D, H, W]
-            if result_arr.ndim == 3:
-                result_arr = result_arr[np.newaxis, ...]  # Add channel dimension
+            # Normalize output dimensions:
+            # For 2D images (input [1, H, W]): functions return [H, W] or [1, H, W]
+            # For 3D volumes (input [D, H, W]): functions return [D, H, W]
+            # Goal: Add channel dimension to get [1, H, W] for 2D or [1, D, H, W] for 3D
+
+            if is_2d_input:
+                # 2D case: some functions return [H, W], others return [1, H, W]
+                if result_arr.ndim == 3 and result_arr.shape[0] == 1:
+                    # Function returned [1, H, W], squeeze Z dimension
+                    result_arr = result_arr[0]  # [1, H, W] -> [H, W]
+                # Now result_arr is [H, W], add channel dimension
+                if result_arr.ndim == 2:
+                    result_arr = result_arr[np.newaxis, ...]  # [H, W] -> [1, H, W]
+            elif is_3d_input:
+                # 3D case: functions return [D, H, W], add channel dimension
+                if result_arr.ndim == 3:
+                    result_arr = result_arr[np.newaxis, ...]  # [D, H, W] -> [1, D, H, W]

             outputs.append(result_arr)
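
The output-shape rule, extracted into a hypothetical helper so the two branches are easy to test in isolation (mirrors the diff's logic; not an actual function in the repo):

    import numpy as np

    def normalize_task_output(result, is_2d_input: bool) -> np.ndarray:
        arr = np.asarray(result, dtype=np.float32)
        if is_2d_input:
            if arr.ndim == 3 and arr.shape[0] == 1:
                arr = arr[0]                # [1, H, W] -> [H, W] (drop the Z=1 dim)
            if arr.ndim == 2:
                arr = arr[np.newaxis, ...]  # [H, W] -> [1, H, W]
        elif arr.ndim == 3:
            arr = arr[np.newaxis, ...]      # [D, H, W] -> [1, D, H, W]
        return arr

    print(normalize_task_output(np.zeros((1, 8, 8)), is_2d_input=True).shape)   # (1, 8, 8)
    print(normalize_task_output(np.zeros((4, 8, 8)), is_2d_input=False).shape)  # (1, 4, 8, 8)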

connectomics/lightning/callbacks.py

Lines changed: 10 additions & 0 deletions

@@ -129,7 +129,12 @@ def on_train_epoch_end(self, trainer, pl_module):

             print(f"✓ Saved visualization for epoch {trainer.current_epoch}")
         except Exception as e:
+            import traceback
             print(f"Epoch-end visualization failed: {e}")
+            print(f"Error type: {type(e).__name__}")
+            if hasattr(e, '__traceback__'):
+                print("Traceback:")
+                traceback.print_exception(type(e), e, e.__traceback__)

     def on_validation_epoch_end(self, trainer, pl_module):
         """Visualize at end of validation epoch based on log_every_n_epochs."""

@@ -172,7 +177,12 @@ def on_validation_epoch_end(self, trainer, pl_module):
                 prefix='val'  # Single tab name (no epoch prefix)
             )
         except Exception as e:
+            import traceback
             print(f"Validation epoch-end visualization failed: {e}")
+            print(f"Error type: {type(e).__name__}")
+            if hasattr(e, '__traceback__'):
+                print("Traceback:")
+                traceback.print_exception(type(e), e, e.__traceback__)


 class NaNDetectionCallback(Callback):
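
The enriched error report, standalone; traceback.print_exception with the (type, value, tb) triple is the long-standing standard-library signature:

    import traceback

    try:
        raise ValueError("demo failure")
    except Exception as e:
        print(f"Epoch-end visualization failed: {e}")
        print(f"Error type: {type(e).__name__}")
        traceback.print_exception(type(e), e, e.__traceback__)  # full stack trace to stderr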
