diff --git a/REFACTORING_PLAN.md b/REFACTORING_PLAN.md
index 47fdade0..8df0bbd0 100644
--- a/REFACTORING_PLAN.md
+++ b/REFACTORING_PLAN.md
@@ -259,6 +259,18 @@ connectomics/lightning/
 └── lit_data.py              # LightningDataModule (684 lines, existing)
 ```
 
+**Migration Steps:**
+1. Create new module files
+2. Move functionality in logical chunks
+3. Update imports in `lit_model.py`
+4. Add integration tests for each module
+5. Update documentation
+
+**Success Criteria:**
+- [ ] Each file < 500 lines
+- [ ] Clear separation of concerns
+- [ ] All existing tests pass
+- [ ] Documentation updated
 **Note:** Multi-task learning was integrated into `deep_supervision.py` (not a separate module) since the logic is tightly coupled with deep supervision.
 
 **Completed Actions:**
@@ -408,7 +420,6 @@ class DataConfig:
 - [x] `deep_supervision_clamp_min: float` (default: -20.0)
 - [x] `deep_supervision_clamp_max: float` (default: 20.0)
 - [x] Validation logic with warning for insufficient weights
-- [x] Backward compatible (defaults match old behavior)
 - [ ] Other hardcoded values (target interpolation, rejection sampling) - Future work
 
 **Status:** ✅ Phase 2.3 (Deep Supervision) completed. Users can now customize deep supervision weights and clamping ranges via config.
@@ -500,14 +511,12 @@ def build_test_transforms(cfg: Config, keys: list[str] = None) -> Compose:
 
 - File size reduced from 791 to 727 lines (-64 lines, ~8% reduction)
 - Eliminated ~80% code duplication
 - Single source of truth for shared transform logic
-- Backward compatible (same public API)
 
 **Action Items:**
 - [x] Extract shared logic into `_build_eval_transforms_impl()`
 - [x] Identify val/test-specific differences (4 key differences)
 - [x] Create mode-specific branching with clear comments
 - [x] Keep wrapper functions for API compatibility
-- [x] Backward compatible (public API unchanged)
 
 **Status:** ✅ Phase 2.5 complete. Code duplication eliminated while preserving all functionality.
@@ -996,10 +1005,8 @@ See Priority 1.3 above for full details.
 ### Mitigation Strategies
 
 1. **Comprehensive testing** before and after each change
-2. **Feature flags** for backward compatibility
-3. **Deprecation warnings** before removal
-4. **Rollback plan** for each phase
-5. **User communication** via release notes
+2. **Rollback plan** for each phase
+3. **User communication** via release notes
 
 ---
 
diff --git a/connectomics/config/hydra_config.py b/connectomics/config/hydra_config.py
index d49220ad..0e76e9cb 100644
--- a/connectomics/config/hydra_config.py
+++ b/connectomics/config/hydra_config.py
@@ -392,9 +392,6 @@ class DataConfig:
         default_factory=list
     )  # Axis permutation for training data (e.g., [2,1,0] for xyz->zyx)
     val_transpose: List[int] = field(default_factory=list)  # Axis permutation for validation data
-    test_transpose: List[int] = field(
-        default_factory=list
-    )  # Axis permutation for test data (deprecated, use inference.data.test_transpose)
 
     # Dataset statistics (for auto-planning)
     target_spacing: Optional[List[float]] = None  # Target voxel spacing [z, y, x] in mm
@@ -868,9 +865,6 @@ class TestTimeAugmentationConfig:
     flip_axes: Any = (
         None  # TTA flip strategy: "all" (8 flips), null (no aug), or list like [[0], [1], [2]]
     )
-    act: Optional[str] = (
-        None  # Single activation for all channels: 'softmax', 'sigmoid', 'tanh', None (deprecated, use channel_activations)
-    )
     channel_activations: Optional[List[Any]] = (
         None  # Per-channel activations: [[start_ch, end_ch, 'activation'], ...] e.g., [[0, 2, 'softmax'], [2, 3, 'sigmoid'], [3, 4, 'tanh']]
     )
diff --git a/connectomics/config/hydra_utils.py b/connectomics/config/hydra_utils.py
index c546d3ab..df9bdb82 100644
--- a/connectomics/config/hydra_utils.py
+++ b/connectomics/config/hydra_utils.py
@@ -244,8 +244,8 @@ def resolve_data_paths(cfg: Config) -> Config:
     Supported paths:
     - Training: cfg.data.train_path + cfg.data.train_image/train_label/train_mask
     - Validation: cfg.data.val_path + cfg.data.val_image/val_label/val_mask
-    - Testing (legacy): cfg.data.test_path + cfg.data.test_image/test_label/test_mask
-    - Inference (primary): cfg.inference.data.test_path + cfg.inference.data.test_image/test_label/test_mask
+    - Testing: cfg.data.test_path + cfg.data.test_image/test_label/test_mask
+    - Inference: cfg.inference.data.test_path + cfg.inference.data.test_image/test_label/test_mask
 
     Args:
         cfg: Config object to resolve paths for
@@ -316,7 +316,7 @@ def _combine_path(base_path: str, file_path: Optional[Union[str, List[str]]]) ->
     cfg.data.val_mask = _combine_path(cfg.data.val_path, cfg.data.val_mask)
     cfg.data.val_json = _combine_path(cfg.data.val_path, cfg.data.val_json)
 
-    # Resolve test paths (legacy support for cfg.data.test_path)
+    # Resolve test paths
     if cfg.data.test_path:
         cfg.data.test_image = _combine_path(cfg.data.test_path, cfg.data.test_image)
         cfg.data.test_label = _combine_path(cfg.data.test_path, cfg.data.test_label)
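For context on the path resolution these hunks touch: the body of `_combine_path` is not shown in the diff, but from its signature and call sites it behaves roughly like the sketch below (a hypothetical reconstruction, not the repo's code):

```python
import os
from typing import List, Optional, Union

def combine_path(base_path: str, file_path: Optional[Union[str, List[str]]]):
    """Join base_path with a single file path or a list of file paths."""
    if file_path is None:
        return None
    if isinstance(file_path, list):
        return [os.path.join(base_path, f) for f in file_path]
    return os.path.join(base_path, file_path)

assert combine_path("/data/test", "im.h5") == "/data/test/im.h5"
assert combine_path("/data/test", ["a.h5", "b.h5"]) == ["/data/test/a.h5", "/data/test/b.h5"]
```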
diff --git a/connectomics/data/augment/build.py b/connectomics/data/augment/build.py
index 7404e52f..a4325d73 100644
--- a/connectomics/data/augment/build.py
+++ b/connectomics/data/augment/build.py
@@ -73,7 +73,7 @@ def build_train_transforms(
     # Load images first (unless using pre-cached dataset)
     if not skip_loading:
         # Use appropriate loader based on dataset type
-        dataset_type = getattr(cfg.data, "dataset_type", "volume")  # Default to volume for backward compatibility
+        dataset_type = getattr(cfg.data, "dataset_type", "volume")
 
         if dataset_type == "filename":
             # For filename-based datasets (PNG, JPG, etc.), use MONAI's LoadImaged
@@ -94,12 +94,9 @@ def build_train_transforms(
         transforms.append(ApplyVolumetricSplitd(keys=keys))
 
     # Apply resize if configured (before cropping)
-    # Check data_transform first (new), then fall back to image_transform.resize (legacy)
     resize_size = None
     if hasattr(cfg.data, "data_transform") and hasattr(cfg.data.data_transform, "resize") and cfg.data.data_transform.resize is not None:
         resize_size = cfg.data.data_transform.resize
-    elif hasattr(cfg.data.image_transform, "resize") and cfg.data.image_transform.resize is not None:
-        resize_size = cfg.data.image_transform.resize
 
     if resize_size:
         # Use bilinear for images, nearest for labels/masks
@@ -247,7 +244,7 @@ def _build_eval_transforms_impl(
     transforms = []
 
     # Load images first - use appropriate loader based on dataset type
-    dataset_type = getattr(cfg.data, "dataset_type", "volume")  # Default to volume for backward compatibility
+    dataset_type = getattr(cfg.data, "dataset_type", "volume")
 
     if dataset_type == "filename":
         # For filename-based datasets (PNG, JPG, etc.), use MONAI's LoadImaged
@@ -260,17 +257,15 @@ def _build_eval_transforms_impl(
     if mode == "val":
         transpose_axes = cfg.data.val_transpose if cfg.data.val_transpose else []
     else:  # mode == "test"
-        # Check both data.test_transpose and inference.data.test_transpose
+        # Use inference.data.test_transpose
        transpose_axes = []
-        if cfg.data.test_transpose:
-            transpose_axes = cfg.data.test_transpose
         if (
             hasattr(cfg, "inference")
             and hasattr(cfg.inference, "data")
             and hasattr(cfg.inference.data, "test_transpose")
             and cfg.inference.data.test_transpose
         ):
-            transpose_axes = cfg.inference.data.test_transpose  # inference takes precedence
+            transpose_axes = cfg.inference.data.test_transpose
 
     transforms.append(
         LoadVolumed(keys=keys, transpose_axes=transpose_axes if transpose_axes else None)
@@ -455,8 +450,8 @@ def _build_augmentations(aug_cfg: AugmentationConfig, keys: list[str], do_2d: bo
         List of MONAI transforms
     """
     transforms = []
-
-    # Get preset mode (default to "some" for backward compatibility)
+
+    # Get preset mode
     preset = getattr(aug_cfg, "preset", "some")
 
     # Helper function to check if augmentation should be applied
diff --git a/connectomics/lightning/callbacks.py b/connectomics/lightning/callbacks.py
index e78ebc83..a2ac12a4 100644
--- a/connectomics/lightning/callbacks.py
+++ b/connectomics/lightning/callbacks.py
@@ -428,25 +428,7 @@ def create_callbacks(cfg) -> list:
         callbacks.append(vis_callback)
 
     # Model checkpoint callback
-    # Support both new unified config (training.checkpoint_*) and old separate config (checkpoint.*)
-    if hasattr(cfg, 'checkpoint') and cfg.checkpoint is not None:
-        # Old config style (backward compatibility)
-        monitor = getattr(cfg.checkpoint, 'monitor', 'val/loss')
-        default_filename = f'epoch={{epoch:03d}}-{monitor}={{{monitor}:.4f}}'
-        filename = getattr(cfg.checkpoint, 'filename', default_filename)
-
-        checkpoint_callback = ModelCheckpoint(
-            monitor=monitor,
-            mode=getattr(cfg.checkpoint, 'mode', 'min'),
-            save_top_k=getattr(cfg.checkpoint, 'save_top_k', 3),
-            save_last=getattr(cfg.checkpoint, 'save_last', True),
-            dirpath=getattr(cfg.checkpoint, 'dirpath', 'checkpoints'),
-            filename=filename,
-            verbose=True
-        )
-        callbacks.append(checkpoint_callback)
-    elif hasattr(cfg, 'monitor') and hasattr(cfg.monitor, 'checkpoint'):
-        # New unified config style (monitor.checkpoint.*)
+    if hasattr(cfg, 'monitor') and hasattr(cfg.monitor, 'checkpoint'):
         monitor = getattr(cfg.monitor.checkpoint, 'monitor', 'val/loss')
         filename = getattr(cfg.monitor.checkpoint, 'filename', None)
         if filename is None:
@@ -465,19 +447,7 @@ def create_callbacks(cfg) -> list:
         callbacks.append(checkpoint_callback)
 
     # Early stopping callback
-    # Support both new unified config (training.early_stopping_*) and old separate config (early_stopping.*)
-    if hasattr(cfg, 'early_stopping') and cfg.early_stopping is not None and cfg.early_stopping.enabled:
-        # Old config style (backward compatibility)
-        early_stop_callback = EarlyStopping(
-            monitor=getattr(cfg.early_stopping, 'monitor', 'val/loss'),
-            patience=getattr(cfg.early_stopping, 'patience', 10),
-            mode=getattr(cfg.early_stopping, 'mode', 'min'),
-            min_delta=getattr(cfg.early_stopping, 'min_delta', 0.0),
-            verbose=True
-        )
-        callbacks.append(early_stop_callback)
-    elif hasattr(cfg, 'monitor') and hasattr(cfg.monitor, 'early_stopping') and getattr(cfg.monitor.early_stopping, 'enabled', False):
-        # New unified config style (monitor.early_stopping.*)
+    if hasattr(cfg, 'monitor') and hasattr(cfg.monitor, 'early_stopping') and getattr(cfg.monitor.early_stopping, 'enabled', False):
         early_stop_callback = EarlyStopping(
             monitor=getattr(cfg.monitor.early_stopping, 'monitor', 'val/loss'),
             patience=getattr(cfg.monitor.early_stopping, 'patience', 10),
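With the legacy branches removed, `create_callbacks` only recognizes the unified `monitor.*` layout. A minimal sketch of the config shape it now expects (attribute names taken from the `getattr` calls above; the values shown are the code's own fallbacks):

```python
from types import SimpleNamespace

cfg = SimpleNamespace(
    monitor=SimpleNamespace(
        checkpoint=SimpleNamespace(
            monitor="val/loss", mode="min", save_top_k=3,
            save_last=True, dirpath="checkpoints",
            filename=None,  # None -> create_callbacks derives a default pattern
        ),
        early_stopping=SimpleNamespace(
            enabled=True, monitor="val/loss", patience=10,
            mode="min", min_delta=0.0,
        ),
    ),
)

# create_callbacks(cfg) now builds ModelCheckpoint and EarlyStopping from
# cfg.monitor.* only; a top-level cfg.checkpoint or cfg.early_stopping is ignored.
```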
diff --git a/connectomics/lightning/lit_model.py b/connectomics/lightning/lit_model.py
index a7dc451a..8581e026 100644
--- a/connectomics/lightning/lit_model.py
+++ b/connectomics/lightning/lit_model.py
@@ -172,6 +172,759 @@ def _setup_test_metrics(self):
         self.test_adapted_rand = True  # Flag to enable adapted_rand computation
         self.test_adapted_rand_results = []  # Store per-batch results
 
+    def _setup_sliding_window_inferer(self):
+        """Initialize MONAI's SlidingWindowInferer based on config."""
+        self.sliding_inferer = None
+
+        if not hasattr(self.cfg, 'inference'):
+            return
+
+        # For 2D models with do_2d=True, disable sliding window inference
+        if getattr(self.cfg.data, 'do_2d', False):
+            warnings.warn(
+                "Sliding-window inference disabled for 2D models with do_2d=True. "
+                "Using direct inference instead.",
+                UserWarning,
+            )
+            return
+
+        roi_size = self._resolve_inferer_roi_size()
+        if roi_size is None:
+            warnings.warn(
+                "Sliding-window inference disabled: unable to determine ROI size. "
+                "Set inference.window_size or model.output_size in the config.",
+                UserWarning,
+            )
+            return
+
+        overlap = self._resolve_inferer_overlap(roi_size)
+        # Use sliding_window.sw_batch_size if specified; otherwise default to system.inference.batch_size
+        system_batch_size = getattr(self.cfg.system.inference, 'batch_size', 1)
+        config_sw_batch_size = getattr(self.cfg.inference.sliding_window, 'sw_batch_size', None)
+        sw_batch_size = max(1, int(config_sw_batch_size if config_sw_batch_size is not None else system_batch_size))
+        mode = getattr(self.cfg.inference.sliding_window, 'blending', 'gaussian')
+        sigma_scale = float(getattr(self.cfg.inference.sliding_window, 'sigma_scale', 0.125))
+        padding_mode = getattr(self.cfg.inference.sliding_window, 'padding_mode', 'constant')
+
+        self.sliding_inferer = SlidingWindowInferer(
+            roi_size=roi_size,
+            sw_batch_size=sw_batch_size,
+            overlap=overlap,
+            mode=mode,
+            sigma_scale=sigma_scale,
+            padding_mode=padding_mode,
+            progress=True,
+        )
+
+        print(
+            "  Sliding-window inference configured: "
+            f"roi_size={roi_size}, overlap={overlap}, sw_batch={sw_batch_size}, "
+            f"mode={mode}, sigma_scale={sigma_scale}, padding={padding_mode}"
+        )
+
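The stride-to-overlap conversion used by `_resolve_inferer_overlap` (next hunk) is `overlap = 1 - stride / roi_size`, clamped per axis to [0.0, 0.99]. A worked example:

```python
roi_size = (128, 128, 128)
stride = (64, 96, 128)

overlaps = tuple(
    max(0.0, min(1.0 - step / size, 0.99)) for size, step in zip(roi_size, stride)
)
print(overlaps)  # (0.5, 0.25, 0.0) -- a stride equal to the window size means no overlap
```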
+    def _resolve_inferer_roi_size(self) -> Optional[Tuple[int, ...]]:
+        """Determine the ROI size for sliding-window inference."""
+        if hasattr(self.cfg, 'inference') and hasattr(self.cfg.inference, 'sliding_window'):
+            window_size = getattr(self.cfg.inference.sliding_window, 'window_size', None)
+            if window_size:
+                return tuple(int(v) for v in window_size)
+
+        if hasattr(self.cfg, 'model') and hasattr(self.cfg.model, 'output_size'):
+            output_size = getattr(self.cfg.model, 'output_size', None)
+            if output_size:
+                roi_size = tuple(int(v) for v in output_size)
+                # For 2D models with do_2d=True, convert to 3D ROI size
+                if getattr(self.cfg.data, 'do_2d', False) and len(roi_size) == 2:
+                    roi_size = (1,) + roi_size  # Add depth dimension
+                return roi_size
+
+        if hasattr(self.cfg, 'data') and hasattr(self.cfg.data, 'patch_size'):
+            patch_size = getattr(self.cfg.data, 'patch_size', None)
+            if patch_size:
+                roi_size = tuple(int(v) for v in patch_size)
+                # For 2D models with do_2d=True, convert to 3D ROI size
+                if getattr(self.cfg.data, 'do_2d', False) and len(roi_size) == 2:
+                    roi_size = (1,) + roi_size  # Add depth dimension
+                return roi_size
+
+        return None
+
+    def _resolve_inferer_overlap(self, roi_size: Tuple[int, ...]) -> Union[float, Tuple[float, ...]]:
+        """Resolve overlap parameter using inference config."""
+        if not hasattr(self.cfg, 'inference') or not hasattr(self.cfg.inference, 'sliding_window'):
+            return 0.5
+
+        overlap = getattr(self.cfg.inference.sliding_window, 'overlap', None)
+        if overlap is not None:
+            if isinstance(overlap, (list, tuple)):
+                return tuple(float(max(0.0, min(o, 0.99))) for o in overlap)
+            return float(max(0.0, min(overlap, 0.99)))
+
+        stride = getattr(self.cfg.inference, 'stride', None)
+        if stride:
+            values: List[float] = []
+            for size, step in zip(roi_size, stride):
+                if size <= 0:
+                    values.append(0.0)
+                    continue
+                ratio = 1.0 - float(step) / float(size)
+                values.append(float(max(0.0, min(ratio, 0.99))))
+            if len(set(values)) == 1:
+                return values[0]
+            return tuple(values)
+
+        return 0.5
+
+    def _extract_main_output(self, outputs: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor:
+        """Extract the primary segmentation logits from model outputs."""
+        if isinstance(outputs, dict):
+            if 'output' not in outputs:
+                raise KeyError("Expected key 'output' in model outputs for deep supervision.")
+            return outputs['output']
+        return outputs
+
+    def _sliding_window_predict(self, inputs: torch.Tensor) -> torch.Tensor:
+        """Wrapper used by MONAI inferer to obtain primary model predictions."""
+        outputs = self.forward(inputs)
+        return self._extract_main_output(outputs)
+
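Why the `_sliding_window_predict` wrapper exists: MONAI's `SlidingWindowInferer` expects its `network` callable to return a single tensor, while a deep-supervision model returns a dict. A toy illustration (stand-in model and shapes, not the repo's code):

```python
import torch

def model(x):  # stand-in for self.forward with deep supervision enabled
    return {"output": x * 2, "ds_1": x}  # auxiliary heads are ignored at inference

def sliding_window_predict(x):
    out = model(x)
    return out["output"] if isinstance(out, dict) else out

x = torch.randn(1, 1, 8, 8, 8)
assert sliding_window_predict(x).shape == x.shape
```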
" + f"Supported: 'sigmoid', 'softmax', 'tanh', None" + ) + + activated_channels.append(channel_tensor) + + # Concatenate all channels back together + tensor = torch.cat(activated_channels, dim=1) + + # Get TTA-specific channel selection or fall back to output_channel + tta_channel = getattr(self.cfg.inference.test_time_augmentation, 'select_channel', None) + if tta_channel is None: + tta_channel = getattr(self.cfg.inference, 'output_channel', None) + + # Apply channel selection + if tta_channel is not None: + if isinstance(tta_channel, int): + if tta_channel == -1: + # -1 means all channels + pass + else: + # Single channel selection + tensor = tensor[:, tta_channel:tta_channel+1, ...] + elif isinstance(tta_channel, (list, tuple)): + # Multiple channel selection + tensor = tensor[:, list(tta_channel), ...] + + return tensor + + def _predict_with_tta(self, images: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Perform test-time augmentation using flips and ensemble predictions. + + Args: + images: Input volume (B, C, D, H, W) or (B, D, H, W) or (D, H, W) + mask: Optional mask to multiply with predictions after ensemble (B, C, D, H, W) or (B, 1, D, H, W) + + Returns: + Averaged predictions from all TTA variants, optionally masked + """ + # Normalize input to 5D tensor (B, C, D, H, W) + if images.ndim == 3: + # (D, H, W) -> (1, 1, D, H, W) + images = images.unsqueeze(0).unsqueeze(0) + warnings.warn( + f"Input shape {images.shape} (D, H, W) automatically expanded to (1, 1, D, H, W)", + UserWarning, + ) + elif images.ndim == 4: + # (B, D, H, W) -> (B, 1, D, H, W) + images = images.unsqueeze(1) + warnings.warn( + f"Input shape (B, D, H, W) automatically expanded to (B, 1, D, H, W)", + UserWarning, + ) + elif images.ndim != 5: + raise ValueError( + f"TTA requires 3D, 4D, or 5D input tensor. Got {images.ndim}D tensor with shape {images.shape}. 
" + f"Expected shapes: (D, H, W), (B, D, H, W), or (B, C, D, H, W)" + ) + + # For 2D models with do_2d=True, squeeze the depth dimension if present + if getattr(self.cfg.data, 'do_2d', False) and images.size(2) == 1: # [B, C, 1, H, W] -> [B, C, H, W] + images = images.squeeze(2) + + # Get TTA configuration (default to no augmentation if not configured) + if hasattr(self.cfg, 'inference') and hasattr(self.cfg.inference, 'test_time_augmentation'): + tta_flip_axes_config = getattr(self.cfg.inference.test_time_augmentation, 'flip_axes', None) + else: + tta_flip_axes_config = None # No config = no augmentation, just forward pass + + # Handle different tta_flip_axes configurations + if tta_flip_axes_config is None: + # null: No augmentation, but still apply tta_act and tta_channel (no ensemble) + if self.sliding_inferer is not None: + pred = self.sliding_inferer(inputs=images, network=self._sliding_window_predict) + else: + pred = self._sliding_window_predict(images) + + # Apply TTA preprocessing (activation + channel selection) even without augmentation + ensemble_result = self._apply_tta_preprocessing(pred) + else: + if tta_flip_axes_config == 'all' or tta_flip_axes_config == []: + # "all" or []: All flips (all combinations of spatial axes) + # Determine spatial axes based on data dimensions + if images.dim() == 5: # 3D data: [B, C, D, H, W] + spatial_axes = [1, 2, 3] # [D, H, W] + elif images.dim() == 4: # 2D data: [B, C, H, W] + spatial_axes = [1, 2] # [H, W] + else: + raise ValueError(f"Unsupported data dimensions: {images.dim()}") + + # Generate all combinations of spatial axes + tta_flip_axes = [[]] # No flip baseline + for r in range(1, len(spatial_axes) + 1): + from itertools import combinations + for combo in combinations(spatial_axes, r): + tta_flip_axes.append(list(combo)) + elif isinstance(tta_flip_axes_config, (list, tuple)): + # Custom list: Add no-flip baseline + user-specified flips + tta_flip_axes = [[]] + list(tta_flip_axes_config) + else: + raise ValueError( + f"Invalid tta_flip_axes: {tta_flip_axes_config}. " + f"Expected 'all' (8 flips), null (no aug), or list of flip axes." + ) + + # Apply TTA with flips, preprocessing, and ensembling + predictions = [] + + for flip_axes in tta_flip_axes: + # Apply flip augmentation + if flip_axes: + x_aug = Flip(spatial_axis=flip_axes)(images) + else: + x_aug = images + + # Inference with sliding window + if self.sliding_inferer is not None: + pred = self.sliding_inferer( + inputs=x_aug, + network=self._sliding_window_predict, + ) + else: + pred = self._sliding_window_predict(x_aug) + + # Invert flip for prediction + if flip_axes: + pred = Flip(spatial_axis=flip_axes)(pred) + + # Apply TTA preprocessing (activation + channel selection) if configured + # Note: This is applied BEFORE ensembling for probability-space averaging + pred_processed = self._apply_tta_preprocessing(pred) + + predictions.append(pred_processed) + + # Ensemble predictions based on configured mode + ensemble_mode = getattr(self.cfg.inference.test_time_augmentation, 'ensemble_mode', 'mean') + stacked_preds = torch.stack(predictions, dim=0) + + if ensemble_mode == 'mean': + ensemble_result = stacked_preds.mean(dim=0) + elif ensemble_mode == 'min': + ensemble_result = stacked_preds.min(dim=0)[0] # min returns (values, indices) + elif ensemble_mode == 'max': + ensemble_result = stacked_preds.max(dim=0)[0] # max returns (values, indices) + else: + raise ValueError(f"Unknown TTA ensemble mode: {ensemble_mode}. 
+    def _apply_postprocessing(self, data: np.ndarray) -> np.ndarray:
+        """
+        Apply postprocessing transformations to predictions.
+
+        This method applies (in order):
+        1. Binary postprocessing (morphological operations, connected-components filtering)
+        2. Scaling (intensity_scale): multiply predictions by a scale factor
+        3. Dtype conversion (intensity_dtype): convert to the target dtype with clamping
+
+        Args:
+            data: Numpy array of predictions (B, C, D, H, W) or (B, D, H, W)
+
+        Returns:
+            Postprocessed predictions with applied transformations
+        """
+        if not hasattr(self.cfg, 'inference') or not hasattr(self.cfg.inference, 'postprocessing'):
+            return data
+
+        postprocessing = self.cfg.inference.postprocessing
+
+        # Step 1: Apply binary postprocessing if configured
+        binary_config = getattr(postprocessing, 'binary', None)
+        if binary_config is not None and getattr(binary_config, 'enabled', False):
+            from connectomics.decoding.postprocess import apply_binary_postprocessing
+
+            # Process each sample in batch
+            # Handle both 4D (B, C, H, W) for 2D data and 5D (B, C, D, H, W) for 3D data
+            print(f"  DEBUG: _apply_postprocessing - input data shape: {data.shape}, ndim: {data.ndim}")
+            if data.ndim == 4:
+                # 2D data: (B, C, H, W)
+                batch_size = data.shape[0]
+                print(f"  DEBUG: _apply_postprocessing - detected 2D data, batch_size: {batch_size}")
+            elif data.ndim == 5:
+                # 3D data: (B, C, D, H, W)
+                batch_size = data.shape[0]
+                print(f"  DEBUG: _apply_postprocessing - detected 3D data, batch_size: {batch_size}")
+            elif data.ndim == 3:
+                # Single 3D volume: (C, D, H, W) or (D, H, W) - add batch dimension
+                batch_size = 1
+                data = data[np.newaxis, ...]  # (1, C, D, H, W) or (1, D, H, W)
+                print(f"  DEBUG: _apply_postprocessing - single 3D sample, added batch dimension")
+            elif data.ndim == 2:
+                # Single 2D image: (H, W) - add batch and channel dimensions
+                batch_size = 1
+                data = data[np.newaxis, np.newaxis, ...]  # (1, 1, H, W)
+                print(f"  DEBUG: _apply_postprocessing - single 2D sample, added batch and channel dimensions")
+            else:
+                batch_size = 1
+
+            # Ensure we have at least 4D: (B, ...) where ... can be (C, H, W) for 2D or (C, D, H, W) for 3D
+            results = []
+            for batch_idx in range(batch_size):
+                sample = data[batch_idx]  # (C, H, W) for 2D or (C, D, H, W) for 3D
+                print(f"  DEBUG: _apply_postprocessing - processing batch_idx {batch_idx}, sample shape: {sample.shape}")
+
+                # Extract foreground probability (always use first channel if channel dimension exists)
+                if sample.ndim == 4:  # (C, D, H, W) - 3D with channel
+                    foreground_prob = sample[0]  # Use first channel -> (D, H, W)
+                elif sample.ndim == 3:
+                    # Could be (C, H, W) for 2D or (D, H, W) for 3D without channel
+                    # If first dim is small (<=4), assume it's channel (2D), otherwise depth (3D)
+                    if sample.shape[0] <= 4:
+                        foreground_prob = sample[0]  # (C, H, W) -> use first channel -> (H, W)
+                    else:
+                        foreground_prob = sample  # (D, H, W) - already single channel
+                elif sample.ndim == 2:  # (H, W) - 2D single channel
+                    foreground_prob = sample
+                else:
+                    foreground_prob = sample
+
+                # Apply binary postprocessing
+                processed = apply_binary_postprocessing(foreground_prob, binary_config)
+
+                # Expand dims to maintain shape consistency with original sample structure
+                if sample.ndim == 4:  # (C, D, H, W) -> processed is (D, H, W)
+                    processed = processed[np.newaxis, ...]  # (1, D, H, W)
+                elif sample.ndim == 3 and sample.shape[0] <= 4:  # (C, H, W) -> processed is (H, W)
+                    processed = processed[np.newaxis, ...]  # (1, H, W)
+                # else: processed is already correct shape (D, H, W) or (H, W)
+
+                results.append(processed)
+
+            # Stack results back into batch
+            print(f"  DEBUG: _apply_postprocessing - stacking {len(results)} results, shapes: {[r.shape for r in results]}")
+            data = np.stack(results, axis=0)
+            print(f"  DEBUG: _apply_postprocessing - after stacking, data shape: {data.shape}")
+
+        # Step 2: Apply scaling if configured
+        intensity_scale = getattr(postprocessing, 'intensity_scale', None)
+        if intensity_scale is not None:
+            data = data * intensity_scale
+
+        # Step 3: Apply dtype conversion if configured
+        target_dtype_str = getattr(postprocessing, 'intensity_dtype', None)
+
+        if target_dtype_str is not None and target_dtype_str != 'float32':
+            # Map string dtype to numpy dtype
+            dtype_map = {
+                'uint8': np.uint8,
+                'int8': np.int8,
+                'uint16': np.uint16,
+                'int16': np.int16,
+                'uint32': np.uint32,
+                'int32': np.int32,
+                'float16': np.float16,
+                'float32': np.float32,
+                'float64': np.float64,
+            }
+
+            if target_dtype_str not in dtype_map:
+                warnings.warn(
+                    f"Unknown dtype '{target_dtype_str}'. Supported: {list(dtype_map.keys())}. "
+                    f"Keeping float32.",
+                    UserWarning,
+                )
+                return data
+
+            target_dtype = dtype_map[target_dtype_str]
+
+            # Clamp to valid range before conversion for integer types
+            if target_dtype_str == 'uint8':
+                data = np.clip(data, 0, 255)
+            elif target_dtype_str == 'int8':
+                data = np.clip(data, -128, 127)
+            elif target_dtype_str == 'uint16':
+                data = np.clip(data, 0, 65535)
+            elif target_dtype_str == 'int16':
+                data = np.clip(data, -32768, 32767)
+            elif target_dtype_str == 'uint32':
+                data = np.clip(data, 0, 4294967295)
+            elif target_dtype_str == 'int32':
+                data = np.clip(data, -2147483648, 2147483647)
+
+            data = data.astype(target_dtype)
+
+        return data
+
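The clamp table above hardcodes integer ranges; they can be cross-checked against `np.iinfo` (shown only as a sanity check, not a proposed change):

```python
import numpy as np

for name in ["uint8", "int8", "uint16", "int16", "uint32", "int32"]:
    info = np.iinfo(np.dtype(name))
    print(name, info.min, info.max)  # e.g. uint8 0 255, int16 -32768 32767

data = np.array([-3.2, 99.7, 300.1])
safe = np.clip(data, np.iinfo(np.uint8).min, np.iinfo(np.uint8).max).astype(np.uint8)
print(safe)  # [  0  99 255] -- clipping before astype avoids integer wrap-around
```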
+    def _apply_decode_mode(self, data: np.ndarray) -> np.ndarray:
+        """
+        Apply decode mode transformations to convert probability maps to instance segmentation.
+
+        Args:
+            data: Numpy array of predictions (B, C, D, H, W) or (C, D, H, W)
+
+        Returns:
+            Decoded segmentation mask(s)
+        """
+        if not hasattr(self.cfg, 'inference'):
+            return data
+
+        # Access decoding config directly from inference
+        decode_modes = getattr(self.cfg.inference, 'decoding', None)
+
+        if not decode_modes:
+            return data
+
+        # Import decoding functions
+        from connectomics.decoding import (
+            decode_binary_thresholding,
+            decode_binary_cc,
+            decode_binary_watershed,
+            decode_binary_contour_cc,
+            decode_binary_contour_watershed,
+            decode_binary_contour_distance_watershed,
+            decode_affinity_cc,
+        )
+
+        # Map function names to actual functions
+        decode_fn_map = {
+            'binary_thresholding': decode_binary_thresholding,
+            'decode_binary_thresholding': decode_binary_thresholding,
+            'decode_binary_cc': decode_binary_cc,
+            'decode_binary_watershed': decode_binary_watershed,
+            'decode_binary_contour_cc': decode_binary_contour_cc,
+            'decode_binary_contour_watershed': decode_binary_contour_watershed,
+            'decode_binary_contour_distance_watershed': decode_binary_contour_distance_watershed,
+            'decode_affinity_cc': decode_affinity_cc,
+        }
+
+        # Process each sample in batch
+        # Handle both 4D (B, C, H, W) for 2D data and 5D (B, C, D, H, W) for 3D data
+        print(f"  DEBUG: _apply_decode_mode - input data shape: {data.shape}, ndim: {data.ndim}")
+        if data.ndim == 4:
+            # 2D data: (B, C, H, W)
+            batch_size = data.shape[0]
+            print(f"  DEBUG: _apply_decode_mode - detected 2D data, batch_size: {batch_size}")
+        elif data.ndim == 5:
+            # 3D data: (B, C, D, H, W)
+            batch_size = data.shape[0]
+            print(f"  DEBUG: _apply_decode_mode - detected 3D data, batch_size: {batch_size}")
+        else:
+            # Single sample: add batch dimension
+            batch_size = 1
+            print(f"  DEBUG: _apply_decode_mode - single sample, adding batch dimension")
+            if data.ndim == 3:
+                data = data[np.newaxis, ...]  # (C, H, W) -> (1, C, H, W)
+            elif data.ndim == 2:
+                data = data[np.newaxis, np.newaxis, ...]  # (H, W) -> (1, 1, H, W)
+
+        results = []
+        for batch_idx in range(batch_size):
+            sample = data[batch_idx]  # (C, H, W) for 2D or (C, D, H, W) for 3D
+            print(f"  DEBUG: _apply_decode_mode - processing batch_idx {batch_idx}, sample shape: {sample.shape}")
+
+            # Apply each decode mode sequentially
+            for decode_cfg in decode_modes:
+                fn_name = decode_cfg.name if hasattr(decode_cfg, 'name') else decode_cfg.get('name')
+                kwargs = decode_cfg.kwargs if hasattr(decode_cfg, 'kwargs') else decode_cfg.get('kwargs', {})
+
+                # Ensure kwargs is a mutable dict (convert from OmegaConf if needed)
+                if hasattr(kwargs, 'items'):
+                    kwargs = dict(kwargs)
+                else:
+                    kwargs = {}
+
+                if fn_name not in decode_fn_map:
+                    warnings.warn(
+                        f"Unknown decode function '{fn_name}'. Available: {list(decode_fn_map.keys())}. "
+                        f"Skipping this decode mode.",
+                        UserWarning,
+                    )
+                    continue
+
+                decode_fn = decode_fn_map[fn_name]
+
+                try:
+                    # Apply decoding function
+                    sample = decode_fn(sample, **kwargs)
+                    # Ensure output has channel dimension for potential chaining
+                    if sample.ndim == 3:
+                        sample = sample[np.newaxis, ...]  # (1, D, H, W)
+                except Exception as e:
+                    warnings.warn(
+                        f"Error applying decode function '{fn_name}': {e}. "
+                        f"Skipping this decode mode.",
+                        UserWarning,
+                    )
+                    continue
+
+            results.append(sample)
+
+        # Stack results back into batch
+        # Always preserve batch dimension, even for batch_size=1
+        print(f"  DEBUG: _apply_decode_mode - stacking {len(results)} results, shapes: {[r.shape for r in results]}")
+        decoded = np.stack(results, axis=0)
+        print(f"  DEBUG: _apply_decode_mode - final decoded shape: {decoded.shape}")
+        return decoded
+
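A sketch of how `decoding` entries chain, using toy stand-ins for the real `connectomics.decoding` functions (the `threshold` kwarg and both toy functions are illustrative, not documented signatures):

```python
import numpy as np

def toy_threshold(x, threshold=0.5):
    return (x > threshold).astype(np.uint8)

def toy_relabel(x):
    # Fake instance labeling: give each foreground voxel a unique positive id
    labels = np.arange(x.size, dtype=np.int64).reshape(x.shape) + 1
    return (labels * x).astype(np.int64)

decode_fn_map = {"threshold": toy_threshold, "relabel": toy_relabel}
decode_modes = [
    {"name": "threshold", "kwargs": {"threshold": 0.5}},
    {"name": "relabel", "kwargs": {}},
]

sample = np.random.rand(1, 4, 8, 8)  # (C, D, H, W) probability map
for entry in decode_modes:  # each entry's output feeds the next entry
    sample = decode_fn_map[entry["name"]](sample, **entry["kwargs"])
print(sample.shape, sample.dtype)  # (1, 4, 8, 8) int64
```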
+    def _resolve_output_filenames(self, batch: Dict[str, Any]) -> List[str]:
+        """
+        Extract and resolve filenames from batch metadata.
+
+        Args:
+            batch: Batch dictionary containing metadata and images
+
+        Returns:
+            List of resolved filenames (without extension)
+        """
+        # Determine batch size from images
+        images = batch.get('image')
+        if images is not None:
+            batch_size = images.shape[0]
+        else:
+            # Fallback: try to infer from metadata
+            batch_size = 1
+
+        meta = batch.get('image_meta_dict')
+        filenames: List[Optional[str]] = []
+
+        print(f"  DEBUG: _resolve_output_filenames - meta type: {type(meta)}, batch_size: {batch_size}")
+
+        # Handle different metadata structures
+        if isinstance(meta, list):
+            # Multiple metadata dicts (one per sample in batch)
+            print(f"  DEBUG: _resolve_output_filenames - meta is list with {len(meta)} items")
+            for idx, meta_item in enumerate(meta):
+                if isinstance(meta_item, dict):
+                    filename = meta_item.get('filename_or_obj')
+                    if filename is not None:
+                        filenames.append(filename)
+                    else:
+                        print(f"  DEBUG: _resolve_output_filenames - meta_item[{idx}] has no filename_or_obj")
+                else:
+                    print(f"  DEBUG: _resolve_output_filenames - meta_item[{idx}] is not a dict: {type(meta_item)}")
+            # Update batch_size from metadata if we have a list
+            batch_size = max(batch_size, len(filenames))
+            print(f"  DEBUG: _resolve_output_filenames - extracted {len(filenames)} filenames from list")
+        elif isinstance(meta, dict):
+            # Single metadata dict
+            print(f"  DEBUG: _resolve_output_filenames - meta is dict")
+            meta_filenames = meta.get('filename_or_obj')
+            if isinstance(meta_filenames, (list, tuple)):
+                filenames = [f for f in meta_filenames if f is not None]
+            elif meta_filenames is not None:
+                filenames = [meta_filenames]
+            # Update batch_size from metadata
+            if len(filenames) > 0:
+                batch_size = max(batch_size, len(filenames))
+            print(f"  DEBUG: _resolve_output_filenames - extracted {len(filenames)} filenames from dict")
+        else:
+            # meta is None or an unexpected type; this can happen if metadata wasn't
+            # preserved through transforms, so fall back to generated names below
+            print(f"  DEBUG: _resolve_output_filenames - meta is None or unexpected type: {type(meta)}")
+
+        resolved_names: List[str] = []
+        for idx in range(batch_size):
+            if idx < len(filenames) and filenames[idx]:
+                resolved_names.append(Path(str(filenames[idx])).stem)
+            else:
+                # Generate fallback filename - this shouldn't happen if metadata is preserved correctly
+                resolved_names.append(f"volume_{self.global_step}_{idx}")
+
+        print(f"  DEBUG: _resolve_output_filenames - returning {len(resolved_names)} resolved names: {resolved_names[:3]}...")
+
+        # Always return exactly batch_size filenames
+        if len(resolved_names) < batch_size:
+            print(f"  WARNING: _resolve_output_filenames - Only {len(resolved_names)} filenames but batch_size is {batch_size}, padding with fallback names")
+            while len(resolved_names) < batch_size:
+                resolved_names.append(f"volume_{self.global_step}_{len(resolved_names)}")
+
+        return resolved_names
+
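The two metadata layouts handled above, in miniature (paths are made up):

```python
from pathlib import Path

meta_as_dict = {"filename_or_obj": ["/data/vol_a.h5", "/data/vol_b.h5"]}
meta_as_list = [{"filename_or_obj": "/data/vol_a.h5"}, {"filename_or_obj": "/data/vol_b.h5"}]

def stems(meta):
    if isinstance(meta, dict):
        names = meta["filename_or_obj"]
    else:
        names = [m["filename_or_obj"] for m in meta]
    return [Path(str(n)).stem for n in names]

# Both layouts resolve to the same extension-free names:
assert stems(meta_as_dict) == stems(meta_as_list) == ["vol_a", "vol_b"]
```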
" + f"Skipping this decode mode.", + UserWarning, + ) + continue + + results.append(sample) + + # Stack results back into batch + # Always preserve batch dimension, even for batch_size=1 + print(f" DEBUG: _apply_decode_mode - stacking {len(results)} results, shapes: {[r.shape for r in results]}") + decoded = np.stack(results, axis=0) + print(f" DEBUG: _apply_decode_mode - final decoded shape: {decoded.shape}") + return decoded + + def _resolve_output_filenames(self, batch: Dict[str, Any]) -> List[str]: + """ + Extract and resolve filenames from batch metadata. + + Args: + batch: Batch dictionary containing metadata and images + + Returns: + List of resolved filenames (without extension) + """ + # Determine batch size from images + images = batch.get('image') + if images is not None: + batch_size = images.shape[0] + else: + # Fallback: try to infer from metadata + batch_size = 1 + + meta = batch.get('image_meta_dict') + filenames: List[Optional[str]] = [] + + print(f" DEBUG: _resolve_output_filenames - meta type: {type(meta)}, batch_size: {batch_size}") + + # Handle different metadata structures + if isinstance(meta, list): + # Multiple metadata dicts (one per sample in batch) + print(f" DEBUG: _resolve_output_filenames - meta is list with {len(meta)} items") + for idx, meta_item in enumerate(meta): + if isinstance(meta_item, dict): + filename = meta_item.get('filename_or_obj') + if filename is not None: + filenames.append(filename) + else: + print(f" DEBUG: _resolve_output_filenames - meta_item[{idx}] has no filename_or_obj") + else: + print(f" DEBUG: _resolve_output_filenames - meta_item[{idx}] is not a dict: {type(meta_item)}") + # Update batch_size from metadata if we have a list + batch_size = max(batch_size, len(filenames)) + print(f" DEBUG: _resolve_output_filenames - extracted {len(filenames)} filenames from list") + elif isinstance(meta, dict): + # Single metadata dict + print(f" DEBUG: _resolve_output_filenames - meta is dict") + meta_filenames = meta.get('filename_or_obj') + if isinstance(meta_filenames, (list, tuple)): + filenames = [f for f in meta_filenames if f is not None] + elif meta_filenames is not None: + filenames = [meta_filenames] + # Update batch_size from metadata + if len(filenames) > 0: + batch_size = max(batch_size, len(filenames)) + print(f" DEBUG: _resolve_output_filenames - extracted {len(filenames)} filenames from dict") + else: + # Handle case where meta might be None or other types + # This can happen if metadata wasn't preserved through transforms + # We'll use fallback filenames based on batch_size + print(f" DEBUG: _resolve_output_filenames - meta is None or unexpected type: {type(meta)}") + pass + + resolved_names: List[str] = [] + for idx in range(batch_size): + if idx < len(filenames) and filenames[idx]: + resolved_names.append(Path(str(filenames[idx])).stem) + else: + # Generate fallback filename - this shouldn't happen if metadata is preserved correctly + resolved_names.append(f"volume_{self.global_step}_{idx}") + + print(f" DEBUG: _resolve_output_filenames - returning {len(resolved_names)} resolved names: {resolved_names[:3]}...") + + # Always return exactly batch_size filenames + if len(resolved_names) < batch_size: + print(f" WARNING: _resolve_output_filenames - Only {len(resolved_names)} filenames but batch_size is {batch_size}, padding with fallback names") + while len(resolved_names) < batch_size: + resolved_names.append(f"volume_{self.global_step}_{len(resolved_names)}") + + return resolved_names + + def _write_outputs( + self, + predictions: 
diff --git a/connectomics/models/loss/regularization.py b/connectomics/models/loss/regularization.py
index 9d6101d3..5a68060d 100644
--- a/connectomics/models/loss/regularization.py
+++ b/connectomics/models/loss/regularization.py
@@ -294,14 +294,6 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor:
         return loss.mean()
 
 
-# Aliases for backward compatibility (if needed in configs)
-BinaryReg = BinaryRegularization
-FgDTConsistency = ForegroundDistanceConsistency
-ContourDTConsistency = ContourDistanceConsistency
-FgContourConsistency = ForegroundContourConsistency
-NonoverlapReg = NonOverlapRegularization
-
-
 __all__ = [
     'BinaryRegularization',
     'ForegroundDistanceConsistency',
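Anything that imported the removed aliases must switch to the canonical names, e.g. (hypothetical caller code):

```python
# before (now raises ImportError):
# from connectomics.models.loss.regularization import BinaryReg

# after:
from connectomics.models.loss.regularization import BinaryRegularization
```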
diff --git a/connectomics/utils/visualizer.py b/connectomics/utils/visualizer.py
index c3a0895b..228797f4 100755
--- a/connectomics/utils/visualizer.py
+++ b/connectomics/utils/visualizer.py
@@ -590,9 +590,3 @@ def _should_visualize(self, trainer, batch_idx: int) -> bool:
         log_every_n_steps = getattr(self.cfg.optimization, "vis_every_n_steps", 100)
 
         return trainer.global_step % log_every_n_steps == 0 and batch_idx == 0
-
-
-# Legacy compatibility
-def create_visualizer(cfg, **kwargs):
-    """Factory function for creating visualizer (backward compatible)."""
-    return Visualizer(cfg, **kwargs)
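Callers of the removed `create_visualizer` factory should construct `Visualizer` directly; if a downstream project still needs a factory callable, a one-line shim in user code suffices (sketch, kwargs pass-through assumed):

```python
from connectomics.utils.visualizer import Visualizer

def make_visualizer(cfg, **kwargs):
    # Drop-in replacement for the removed create_visualizer factory.
    return Visualizer(cfg, **kwargs)
```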