Skip to content

Commit 1c687f4

Browse files
author
Donglai Wei
committed
fix multi-class loss and sdt
1 parent 180d265 commit 1c687f4

File tree

12 files changed

+775
-106
lines changed

12 files changed

+775
-106
lines changed

connectomics/config/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
validate_config,
1616
get_config_hash,
1717
create_experiment_name,
18+
resolve_data_paths,
1819
)
1920

2021
# Auto-configuration system
@@ -47,6 +48,7 @@
4748
'validate_config',
4849
'get_config_hash',
4950
'create_experiment_name',
51+
'resolve_data_paths',
5052
# Auto-configuration
5153
'auto_plan_config',
5254
'AutoConfigPlanner',

connectomics/config/hydra_config.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -334,16 +334,23 @@ class DataConfig:
334334
# Dataset type
335335
dataset_type: Optional[str] = None # Type of dataset: None (volume), 'filename', 'tile', etc.
336336

337+
# Base path (prepended to train_image, train_label, etc. if set)
338+
train_path: str = "" # Base path for training data (e.g., "/path/to/dataset/")
339+
val_path: str = "" # Base path for validation data
340+
test_path: str = "" # Base path for test data
341+
337342
# Paths - Volume-based datasets
338-
train_image: Optional[str] = None
339-
train_label: Optional[str] = None
340-
train_mask: Optional[str] = None # Valid region mask for training
341-
val_image: Optional[str] = None
342-
val_label: Optional[str] = None
343-
val_mask: Optional[str] = None # Valid region mask for validation
344-
test_image: Optional[str] = None
345-
test_label: Optional[str] = None
346-
test_mask: Optional[str] = None # Valid region mask for testing
343+
# These can be strings (single file), lists (multiple files), or None
344+
# Using Any to support both str and List[str] (OmegaConf doesn't support Union of containers)
345+
train_image: Any = None # str, List[str], or None
346+
train_label: Any = None # str, List[str], or None
347+
train_mask: Any = None # str, List[str], or None (Valid region mask for training)
348+
val_image: Any = None # str, List[str], or None
349+
val_label: Any = None # str, List[str], or None
350+
val_mask: Any = None # str, List[str], or None (Valid region mask for validation)
351+
test_image: Any = None # str, List[str], or None
352+
test_label: Any = None # str, List[str], or None
353+
test_mask: Any = None # str, List[str], or None (Valid region mask for testing)
347354

348355
# Paths - JSON/filename-based datasets
349356
train_json: Optional[str] = None # JSON file with image/label file lists
@@ -413,6 +420,9 @@ class DataConfig:
413420
True # Preload volumes into memory for fast random cropping (default: True)
414421
)
415422

423+
# Reject sampling configuration (for volumetric patch sampling)
424+
reject_sampling: Optional[Dict[str, Any]] = None # Dict with 'size_thres' and 'p' keys
425+
416426
# Multi-channel label transformation (for affinity maps, distance transforms, etc.)
417427
label_transform: LabelTransformConfig = field(default_factory=LabelTransformConfig)
418428

@@ -745,9 +755,9 @@ class AugmentationConfig:
745755
class InferenceDataConfig:
746756
"""Inference data configuration."""
747757

748-
test_image: Optional[str] = None # Singular form for compatibility
749-
test_label: Optional[str] = None # Singular form for compatibility
750-
test_mask: Optional[str] = None # Optional mask for inference
758+
test_image: Any = None # str, List[str], or None - Can be single file or list of files
759+
test_label: Any = None # str, List[str], or None - Can be single file or list of files
760+
test_mask: Any = None # str, List[str], or None - Optional mask for inference
751761
test_resolution: Optional[List[float]] = (
752762
None # Test data resolution [z, y, x] in nm (e.g., [30, 6, 6])
753763
)
@@ -763,12 +773,14 @@ class SlidingWindowConfig:
763773

764774
window_size: Optional[List[int]] = None
765775
sw_batch_size: Optional[int] = None # If None, will use system.inference.batch_size
766-
overlap: float = 0.5
776+
overlap: Optional[float] = 0.5 # Overlap ratio (0-1), or None to use stride instead
777+
stride: Optional[List[int]] = None # Explicit stride (overrides overlap if set)
767778
blending: str = "gaussian" # 'gaussian' or 'constant' - blending mode for overlapping patches
768779
sigma_scale: float = (
769780
0.125 # Gaussian sigma scale (only for blending='gaussian'); larger = smoother blending
770781
)
771782
padding_mode: str = "constant" # Padding mode at volume boundaries
783+
pad_size: Optional[List[int]] = None # Padding size for context (e.g., [16, 32, 32])
772784

773785

774786
@dataclass

connectomics/config/hydra_utils.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,95 @@ def create_experiment_name(cfg: Config) -> str:
231231
return "_".join(parts)
232232

233233

234+
def resolve_data_paths(cfg: Config) -> Config:
    """
    Resolve data paths by combining base paths (train_path, val_path, test_path)
    with relative file paths (train_image, train_label, etc.).

    This function modifies the config in-place by:
    1. Prepending base paths to relative file paths (absolute paths are kept)
    2. Expanding glob patterns (``*`` and ``?``) to sorted file lists
    3. Flattening nested lists from glob expansion

    A split is only processed when its base path is non-empty. A pattern that
    matches no files is left as the (joined) pattern string so that later
    validation can report it. Inference paths (``cfg.inference.data.test_*``)
    are resolved against ``cfg.data.test_path``.

    Path collections may be any non-string sequence (``list``, ``tuple``, or a
    sequence-like config container); they are always returned as a plain
    ``list``.

    Note:
        The function is not idempotent for *relative* base paths: a second
        call joins the base path onto the already-joined (still relative)
        result again.

    Args:
        cfg: Config object to resolve paths for

    Returns:
        Config object with resolved paths (same object, modified in-place)

    Example:
        >>> cfg.data.train_path = "/data/barcode/"
        >>> cfg.data.train_image = ["PT37/*_raw.tif", "file.tif"]
        >>> resolve_data_paths(cfg)
        >>> print(cfg.data.train_image)
        ['/data/barcode/PT37/img1_raw.tif', '/data/barcode/PT37/img2_raw.tif', '/data/barcode/file.tif']
    """
    import os
    from collections.abc import Sequence
    from glob import glob

    def _combine_path(base_path: str, file_path: Optional[Union[str, List[str]]]) -> Optional[Union[str, List[str]]]:
        """Helper to combine base path with file path(s) and expand globs."""
        if file_path is None:
            return file_path

        # Handle any non-string sequence of paths. Checking only for ``list``
        # would send tuples and config list containers down the string branch,
        # where os.path.isabs() raises TypeError.
        if isinstance(file_path, Sequence) and not isinstance(file_path, (str, bytes)):
            result = []
            for p in file_path:
                resolved = _combine_path(base_path, p)
                # If resolved is a list (from glob expansion), extend
                if isinstance(resolved, list):
                    result.extend(resolved)
                else:
                    result.append(resolved)
            return result

        # Handle string path: combine with base path if relative
        if base_path and not os.path.isabs(file_path):
            file_path = os.path.join(base_path, file_path)

        # Expand glob patterns; sorted() keeps the file order deterministic
        if "*" in file_path or "?" in file_path:
            expanded = sorted(glob(file_path))
            if expanded:
                return expanded
            # No matches - return original pattern (will be caught by validation)

        return file_path

    # Resolve training, validation and test paths (same order as the fields
    # are declared: image, label, mask, json)
    for split in ("train", "val", "test"):
        base = getattr(cfg.data, f"{split}_path")
        if not base:
            continue
        for kind in ("image", "label", "mask", "json"):
            field_name = f"{split}_{kind}"
            setattr(cfg.data, field_name, _combine_path(base, getattr(cfg.data, field_name)))

    # Also resolve inference data paths
    if cfg.data.test_path and cfg.inference.data:
        for kind in ("image", "label", "mask"):
            field_name = f"test_{kind}"
            setattr(
                cfg.inference.data,
                field_name,
                _combine_path(cfg.data.test_path, getattr(cfg.inference.data, field_name)),
            )

    return cfg
321+
322+
234323
__all__ = [
235324
"load_config",
236325
"save_config",
@@ -242,4 +331,5 @@ def create_experiment_name(cfg: Config) -> str:
242331
"validate_config",
243332
"get_config_hash",
244333
"create_experiment_name",
334+
"resolve_data_paths",
245335
]

connectomics/data/process/distance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def skeleton_aware_distance_transform(
232232
bg_value: float = -1.0,
233233
relabel: bool = True,
234234
padding: bool = False,
235-
resolution: Tuple[float] = (1.0, 1.0),
235+
resolution: Tuple[float] = (1.0, 1.0, 1.0),
236236
alpha: float = 0.8,
237237
smooth: bool = True,
238238
smooth_skeleton_only: bool = True,

connectomics/lightning/lit_model.py

Lines changed: 101 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,33 +1005,72 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> STEP_
10051005
ds_weights = [1.0] + [0.5 ** i for i in range(1, len(ds_outputs) + 1)]
10061006
all_outputs = [main_output] + ds_outputs
10071007

1008+
# Check if multi-task learning is configured
1009+
is_multi_task = hasattr(self.cfg.model, 'multi_task_config') and self.cfg.model.multi_task_config is not None
1010+
10081011
for scale_idx, (output, ds_weight) in enumerate(zip(all_outputs, ds_weights)):
10091012
# Match target to output size
10101013
target = self._match_target_to_output(labels, output)
10111014

10121015
# Compute loss for this scale
10131016
scale_loss = 0.0
1014-
for loss_fn, weight in zip(self.loss_functions, self.loss_weights):
1015-
loss = loss_fn(output, target)
10161017

1017-
# Check for NaN/Inf immediately after computing loss
1018-
if self.enable_nan_detection and (torch.isnan(loss) or torch.isinf(loss)):
1019-
print(f"\n{'='*80}")
1020-
print(f"⚠️ NaN/Inf detected in loss computation!")
1021-
print(f"{'='*80}")
1022-
print(f"Loss function: {loss_fn.__class__.__name__}")
1023-
print(f"Loss value: {loss.item()}")
1024-
print(f"Scale: {scale_idx}, Weight: {weight}")
1025-
print(f"Output shape: {output.shape}, range: [{output.min():.4f}, {output.max():.4f}]")
1026-
print(f"Target shape: {target.shape}, range: [{target.min():.4f}, {target.max():.4f}]")
1027-
print(f"Output contains NaN: {torch.isnan(output).any()}")
1028-
print(f"Target contains NaN: {torch.isnan(target).any()}")
1029-
if self.debug_on_nan:
1030-
print(f"\nEntering debugger...")
1031-
pdb.set_trace()
1032-
raise ValueError(f"NaN/Inf in loss at scale {scale_idx}")
1033-
1034-
scale_loss += loss * weight
1018+
if is_multi_task:
1019+
# Multi-task learning with deep supervision:
1020+
# Apply specific losses to specific channels at each scale
1021+
for task_idx, task_config in enumerate(self.cfg.model.multi_task_config):
1022+
start_ch, end_ch, task_name, loss_indices = task_config
1023+
1024+
# Extract channels for this task
1025+
task_output = output[:, start_ch:end_ch, ...]
1026+
task_target = target[:, start_ch:end_ch, ...]
1027+
1028+
# Apply specified losses for this task
1029+
for loss_idx in loss_indices:
1030+
loss_fn = self.loss_functions[loss_idx]
1031+
weight = self.loss_weights[loss_idx]
1032+
1033+
loss = loss_fn(task_output, task_target)
1034+
1035+
# Check for NaN/Inf
1036+
if self.enable_nan_detection and (torch.isnan(loss) or torch.isinf(loss)):
1037+
print(f"\n{'='*80}")
1038+
print(f"⚠️ NaN/Inf detected in deep supervision multi-task loss!")
1039+
print(f"{'='*80}")
1040+
print(f"Scale: {scale_idx}, Task: {task_name} (channels {start_ch}:{end_ch})")
1041+
print(f"Loss function: {loss_fn.__class__.__name__} (index {loss_idx})")
1042+
print(f"Loss value: {loss.item()}")
1043+
print(f"Output shape: {task_output.shape}, range: [{task_output.min():.4f}, {task_output.max():.4f}]")
1044+
print(f"Target shape: {task_target.shape}, range: [{task_target.min():.4f}, {task_target.max():.4f}]")
1045+
if self.debug_on_nan:
1046+
print(f"\nEntering debugger...")
1047+
pdb.set_trace()
1048+
raise ValueError(f"NaN/Inf in deep supervision loss at scale {scale_idx}, task {task_name}")
1049+
1050+
scale_loss += loss * weight
1051+
else:
1052+
# Standard deep supervision: apply all losses to all outputs
1053+
for loss_fn, weight in zip(self.loss_functions, self.loss_weights):
1054+
loss = loss_fn(output, target)
1055+
1056+
# Check for NaN/Inf immediately after computing loss
1057+
if self.enable_nan_detection and (torch.isnan(loss) or torch.isinf(loss)):
1058+
print(f"\n{'='*80}")
1059+
print(f"⚠️ NaN/Inf detected in loss computation!")
1060+
print(f"{'='*80}")
1061+
print(f"Loss function: {loss_fn.__class__.__name__}")
1062+
print(f"Loss value: {loss.item()}")
1063+
print(f"Scale: {scale_idx}, Weight: {weight}")
1064+
print(f"Output shape: {output.shape}, range: [{output.min():.4f}, {output.max():.4f}]")
1065+
print(f"Target shape: {target.shape}, range: [{target.min():.4f}, {target.max():.4f}]")
1066+
print(f"Output contains NaN: {torch.isnan(output).any()}")
1067+
print(f"Target contains NaN: {torch.isnan(target).any()}")
1068+
if self.debug_on_nan:
1069+
print(f"\nEntering debugger...")
1070+
pdb.set_trace()
1071+
raise ValueError(f"NaN/Inf in loss at scale {scale_idx}")
1072+
1073+
scale_loss += loss * weight
10351074

10361075
total_loss += scale_loss * ds_weight
10371076
loss_dict[f'train_loss_scale_{scale_idx}'] = scale_loss.item()
@@ -1100,15 +1139,38 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> STE
11001139
ds_weights = [1.0] + [0.5 ** i for i in range(1, len(ds_outputs) + 1)]
11011140
all_outputs = [main_output] + ds_outputs
11021141

1142+
# Check if multi-task learning is configured
1143+
is_multi_task = hasattr(self.cfg.model, 'multi_task_config') and self.cfg.model.multi_task_config is not None
1144+
11031145
for scale_idx, (output, ds_weight) in enumerate(zip(all_outputs, ds_weights)):
11041146
# Match target to output size
11051147
target = self._match_target_to_output(labels, output)
11061148

11071149
# Compute loss for this scale
11081150
scale_loss = 0.0
1109-
for loss_fn, weight in zip(self.loss_functions, self.loss_weights):
1110-
loss = loss_fn(output, target)
1111-
scale_loss += loss * weight
1151+
1152+
if is_multi_task:
1153+
# Multi-task learning with deep supervision:
1154+
# Apply specific losses to specific channels at each scale
1155+
for task_idx, task_config in enumerate(self.cfg.model.multi_task_config):
1156+
start_ch, end_ch, task_name, loss_indices = task_config
1157+
1158+
# Extract channels for this task
1159+
task_output = output[:, start_ch:end_ch, ...]
1160+
task_target = target[:, start_ch:end_ch, ...]
1161+
1162+
# Apply specified losses for this task
1163+
for loss_idx in loss_indices:
1164+
loss_fn = self.loss_functions[loss_idx]
1165+
weight = self.loss_weights[loss_idx]
1166+
1167+
loss = loss_fn(task_output, task_target)
1168+
scale_loss += loss * weight
1169+
else:
1170+
# Standard deep supervision: apply all losses to all outputs
1171+
for loss_fn, weight in zip(self.loss_functions, self.loss_weights):
1172+
loss = loss_fn(output, target)
1173+
scale_loss += loss * weight
11121174

11131175
total_loss += scale_loss * ds_weight
11141176
loss_dict[f'val_loss_scale_{scale_idx}'] = scale_loss.item()
@@ -1367,6 +1429,10 @@ def _match_target_to_output(
13671429
For segmentation masks, uses nearest-neighbor interpolation to preserve labels.
13681430
For continuous targets, uses trilinear interpolation.
13691431
1432+
IMPORTANT: For continuous targets in range [-1, 1] (e.g., tanh-normalized SDT),
1433+
trilinear interpolation can cause overshooting beyond bounds. We clamp the
1434+
resized targets back to [-1, 1] to prevent loss explosion.
1435+
13701436
Args:
13711437
target: Target tensor of shape (B, C, D, H, W)
13721438
output: Output tensor of shape (B, C, D', H', W')
@@ -1396,6 +1462,18 @@ def _match_target_to_output(
13961462
align_corners=False,
13971463
)
13981464

1465+
# CRITICAL FIX: Clamp resized targets to prevent interpolation overshooting
1466+
# For targets in range [-1, 1] (e.g., tanh-normalized SDT), trilinear interpolation
1467+
# can produce values outside this range (e.g., -1.2, 1.3) which causes loss explosion
1468+
# when used with tanh-activated predictions.
1469+
# Check if targets are in typical normalized ranges:
1470+
if target.min() >= -1.5 and target.max() <= 1.5:
1471+
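# Likely normalized to [-1, 1] (with some tolerance for existing overshoots).
# NOTE(review): any target with min >= 0.0 and max <= 1.5 also satisfies this condition, so the
# elif branch below is never reached and [0, 1] targets are clamped to [-1, 1] rather than [0, 1].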
1472+
target_resized = torch.clamp(target_resized, -1.0, 1.0)
1473+
elif target.min() >= 0.0 and target.max() <= 1.5:
1474+
# Likely normalized to [0, 1]
1475+
target_resized = torch.clamp(target_resized, 0.0, 1.0)
1476+
13991477
return target_resized
14001478

14011479
def configure_optimizers(self) -> Dict[str, Any]:

0 commit comments

Comments
 (0)