more robust ignore label support #543
base: main
The first file in the diff (the dataset/config module that builds the chunk slicer and the train/validation transform pipelines) resolves a per-target ignore value and threads it through `_setup_chunk_slicer`:

```diff
@@ -357,6 +357,7 @@ def _setup_chunk_slicer(self) -> None:
         stride_override = tuple(int(v) for v in stride_override)

         channel_selector = self._resolve_label_channel_selector(target_names)
+        ignore_map = self._resolve_target_ignore_labels(target_names)

         config = ChunkSliceConfig(
             patch_size=patch_size,
@@ -383,6 +384,7 @@ def _setup_chunk_slicer(self) -> None:
         labels: Dict[str, Optional[np.ndarray]] = {}
         label_source = None
         cache_key_path: Optional[Path] = None
+        label_ignore_value: Optional[Union[int, float]] = None

         for target_name in target_names:
             try:
@@ -393,10 +395,16 @@ def _setup_chunk_slicer(self) -> None:
                 ) from exc
             label_array = self._get_entry_label(entry)
             labels[target_name] = label_array
-            if label_source is None:
-                source_candidate = self._get_entry_label_source(entry)
-                if source_candidate is not None:
-                    label_source = source_candidate
+            ignore_candidate = ignore_map.get(target_name)
+            source_candidate = self._get_entry_label_source(entry)
+            if source_candidate is None:
+                source_candidate = label_array
+            if label_source is None and source_candidate is not None:
+                label_source = source_candidate
+                if label_ignore_value is None and ignore_candidate is not None:
+                    label_ignore_value = ignore_candidate
+            elif label_ignore_value is None and ignore_candidate is not None:
+                label_ignore_value = ignore_candidate
             if cache_key_path is None:
                 path_candidate = self._get_entry_label_path(entry)
                 if path_candidate:
@@ -410,6 +418,7 @@ def _setup_chunk_slicer(self) -> None:
                 labels=labels,
                 label_source=label_source,
                 cache_key_path=cache_key_path,
+                label_ignore_value=label_ignore_value,
                 meshes=reference_info.get('meshes', {}),
             )
         )
```
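To make the precedence in that last resolution loop easier to follow, here is a minimal standalone sketch (the entries are hypothetical; only the if/elif structure mirrors the diff): the first target that yields a usable label source becomes `label_source`, and the first non-null ignore value becomes `label_ignore_value`, whichever branch it arrives through.

```python
# Hypothetical per-target candidates; in the real code these come from
# self._get_entry_label_source(entry) and ignore_map.get(target_name).
candidates = [
    ("sheet", None, 255),           # no usable label source, but defines ignore value 255
    ("damage", "labels.zarr", None),
]

label_source = None
label_ignore_value = None
for name, source_candidate, ignore_candidate in candidates:
    if label_source is None and source_candidate is not None:
        label_source = source_candidate
        if label_ignore_value is None and ignore_candidate is not None:
            label_ignore_value = ignore_candidate
    elif label_ignore_value is None and ignore_candidate is not None:
        label_ignore_value = ignore_candidate

print(label_source, label_ignore_value)  # labels.zarr 255
```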
```diff
@@ -426,6 +435,20 @@ def _resolve_label_channel_selector(
             return self._normalize_channel_selector(selector)
         return None

+    def _resolve_target_ignore_labels(
+        self, target_names: Sequence[str]
+    ) -> Dict[str, Union[int, float]]:
+        ignore_map: Dict[str, Union[int, float]] = {}
+        for target_name in target_names:
+            info = self.targets.get(target_name) or {}
+            for alias in ("ignore_index", "ignore_label", "ignore_value"):
+                if alias in info:
+                    value = info.get(alias)
+                    if value is not None:
+                        ignore_map[target_name] = value  # store first non-null alias per target
+                        break
+        return ignore_map
+
     @staticmethod
     def _normalize_channel_selector(
         selector: object,
```
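As a rough illustration (target names and config values here are hypothetical, not from the PR), the alias resolution boils down to scanning each target's config for the first non-null of the three accepted keys:

```python
from typing import Dict, Sequence, Union

def resolve_target_ignore_labels(
    targets: Dict[str, dict], target_names: Sequence[str]
) -> Dict[str, Union[int, float]]:
    """Standalone sketch of the alias scan added in this PR."""
    ignore_map: Dict[str, Union[int, float]] = {}
    for target_name in target_names:
        info = targets.get(target_name) or {}
        for alias in ("ignore_index", "ignore_label", "ignore_value"):
            if alias in info and info[alias] is not None:
                ignore_map[target_name] = info[alias]  # first non-null alias wins
                break
    return ignore_map

# Hypothetical target configuration:
targets = {
    "sheet": {"ignore_label": 255},
    "normals": {"ignore_value": float("nan")},
    "damage": {},
}
print(resolve_target_ignore_labels(targets, ["sheet", "normals", "damage"]))
# {'sheet': 255, 'normals': nan}
```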
The same aliases are honoured when the skeleton transform is added to the training pipeline:

```diff
@@ -882,7 +905,21 @@ def _create_training_transforms(self):
         skeleton_targets = self._skeleton_loss_targets()
         if skeleton_targets:
             from vesuvius.models.augmentation.transforms.utils.skeleton_transform import MedialSurfaceTransform
-            transforms.append(MedialSurfaceTransform(do_tube=False, target_keys=skeleton_targets))
+            ignore_values = {}
+            for target_name in skeleton_targets:
+                cfg = self.targets.get(target_name, {}) if isinstance(self.targets, dict) else {}
+                for alias in ("ignore_index", "ignore_label", "ignore_value"):
+                    value = cfg.get(alias)
+                    if value is not None:
+                        ignore_values[target_name] = value
+                        break
+            transforms.append(
+                MedialSurfaceTransform(
+                    do_tube=False,
+                    target_keys=skeleton_targets,
+                    ignore_values=ignore_values or None,
+                )
+            )
             print(f"Added MedialSurfaceTransform to training pipeline for targets: {', '.join(skeleton_targets)}")

         return ComposeTransforms(transforms)
```
And in the validation pipeline:

```diff
@@ -899,7 +936,21 @@ def _create_validation_transforms(self):
         from vesuvius.models.augmentation.transforms.utils.skeleton_transform import MedialSurfaceTransform

         transforms = []
-        transforms.append(MedialSurfaceTransform(do_tube=False, target_keys=skeleton_targets))
+        ignore_values = {}
+        for target_name in skeleton_targets:
+            cfg = self.targets.get(target_name, {}) if isinstance(self.targets, dict) else {}
+            for alias in ("ignore_index", "ignore_label", "ignore_value"):
+                value = cfg.get(alias)
+                if value is not None:
+                    ignore_values[target_name] = value
+                    break
+        transforms.append(
+            MedialSurfaceTransform(
+                do_tube=False,
+                target_keys=skeleton_targets,
+                ignore_values=ignore_values or None,
+            )
+        )
```

A reviewer left a suggested change on the new `ignore_values=ignore_values or None,` line:

```diff
-                ignore_values=ignore_values or None,
+                ignore_values=ignore_values if ignore_values else None,
```
The second file in the diff (the patch-finding utilities) gains a `zero_ignore_labels` helper and applies it wherever label patches and blocks are evaluated:

```diff
@@ -1,6 +1,6 @@
 import logging
 import time
-from typing import Sequence, Union
+from typing import Optional, Sequence, Union

 import numpy as np
 from tqdm import tqdm
@@ -146,6 +146,28 @@ def reduce_block_to_scalar(
     flat = arr.reshape(spatial_shape + (int(np.prod(extra_shape)),))
     return np.linalg.norm(flat, axis=-1)

+
+def zero_ignore_labels(array: np.ndarray, ignore_label: Union[int, float]) -> np.ndarray:
+    """Return a copy of ``array`` with the ignore label value zeroed out."""
+
+    arr = np.asarray(array)
+    if arr.size == 0:
+        return arr
+
+    if isinstance(ignore_label, float) and np.isnan(ignore_label):
+        mask = np.isnan(arr)
+    else:
+        try:
+            mask = (arr == ignore_label)
+        except TypeError:
+            return arr
+
+    if not np.any(mask):
+        return arr
+
+    result = arr.copy()
+    result[mask] = 0
+    return result

 def check_patch_chunk(
     chunk,
```
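For reference, here is the helper exercised on a couple of toy arrays (the arrays are made up for illustration; the function body is copied from the diff above, lightly condensed):

```python
import numpy as np

def zero_ignore_labels(array, ignore_label):
    """Copy of the helper added in this PR (condensed from the diff above)."""
    arr = np.asarray(array)
    if arr.size == 0:
        return arr
    if isinstance(ignore_label, float) and np.isnan(ignore_label):
        mask = np.isnan(arr)
    else:
        try:
            mask = (arr == ignore_label)
        except TypeError:
            return arr
    if not np.any(mask):
        return arr
    result = arr.copy()
    result[mask] = 0
    return result

labels = np.array([[0, 1, 255], [255, 1, 0]])
print(zero_ignore_labels(labels, 255))               # 255 voxels become background (0)
print(zero_ignore_labels(labels, 7))                 # no match -> returned unchanged

nan_labels = np.array([[0.0, np.nan], [1.0, np.nan]])
print(zero_ignore_labels(nan_labels, float("nan")))  # NaN ignore values are also zeroed
```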
```diff
@@ -154,6 +176,7 @@ def check_patch_chunk(
     bbox_threshold=0.5,
     label_threshold=0.05,
     channel_selector: Union[int, Sequence[int], None] = None,
+    ignore_label: Optional[Union[int, float]] = None,
 ):
     """Identify valid label patches within a chunk of candidate positions."""

@@ -182,6 +205,8 @@ def check_patch_chunk(
             patch = sheet_label[base_slice + direct_selector]
         else:
             patch = sheet_label[base_slice]
+        if ignore_label is not None:
+            patch = zero_ignore_labels(patch, ignore_label)
         patch = collapse_patch_to_spatial(
             patch,
             spatial_ndim=2,
@@ -210,6 +235,8 @@ def check_patch_chunk(
             patch = sheet_label[base_slice + direct_selector]
         else:
             patch = sheet_label[base_slice]
+        if ignore_label is not None:
+            patch = zero_ignore_labels(patch, ignore_label)
         patch = collapse_patch_to_spatial(
             patch,
             spatial_ndim=3,
```
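The practical effect inside `check_patch_chunk` is that patches whose only nonzero content is the ignore value no longer look like they contain labels. A toy illustration (the exact fraction computation in `check_patch_chunk` is not shown in this diff, so this only sketches the intent):

```python
import numpy as np

label_threshold = 0.05              # default in check_patch_chunk
patch = np.zeros((4, 4), dtype=np.uint8)
patch[2:, :] = 255                  # ignore region only, no genuine foreground

fraction_raw = np.count_nonzero(patch) / patch.size        # 0.5
zeroed = np.where(patch == 255, 0, patch)                   # what zero_ignore_labels yields here
fraction_ignored = np.count_nonzero(zeroed) / zeroed.size   # 0.0

print(fraction_raw >= label_threshold)      # True  -> patch previously looked valid
print(fraction_ignored >= label_threshold)  # False -> correctly rejected with ignore handling
```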
```diff
@@ -248,6 +275,7 @@ def find_valid_patches(
     num_workers=4,
     downsample_level=1,
     channel_selectors: Sequence[Union[int, Sequence[int], None]] | None = None,
+    ignore_labels: Sequence[Optional[Union[int, float]]] | None = None,
 ):
     """
     Finds patches that contain:
@@ -264,12 +292,15 @@ def find_valid_patches(
         max_z, max_y, max_x: maximum coordinates for patch extraction (full resolution)
         num_workers: number of processes for parallel processing
         downsample_level: Resolution level to use for patch finding (0=full res, 1=2x downsample, etc.)
+        ignore_labels: Optional per-volume ignore values that should be treated as background.

     Returns:
         List of dictionaries with 'volume_idx', 'volume_name', and 'start_pos' (coordinates at full resolution)
     """
     if len(label_arrays) != len(label_names):
         raise ValueError("Number of label arrays must match number of label names")
+    if ignore_labels is not None and len(ignore_labels) != len(label_arrays):
+        raise ValueError("ignore_labels must match number of label arrays when provided")

     all_valid_patches = []

@@ -304,6 +335,17 @@ def find_valid_patches(
         selector = None
         if channel_selectors is not None:
             selector = channel_selectors[vol_idx]
+        ignore_label = None
+        if ignore_labels is not None:
+            ignore_label = ignore_labels[vol_idx]
+
+        if label_array is None:
+            logger.warning(
+                "Volume '%s' has no label array available at index %d; skipping validation",
+                label_name,
+                vol_idx,
+            )
+            continue

         # Access the appropriate resolution level for patch finding
         actual_downsample_factor = downsample_factor
```
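To make the alignment requirement concrete, here is a hypothetical caller-side sketch (array contents and names are made up; only the length check and per-volume lookup mirror the diff):

```python
from typing import Optional, Sequence, Union
import numpy as np

# Hypothetical volumes: volume A uses 255 as its ignore value, volume B has none.
label_arrays = [np.random.randint(0, 2, (8, 8)), np.random.randint(0, 2, (8, 8))]
label_names = ["volume_a", "volume_b"]
ignore_labels: Sequence[Optional[Union[int, float]]] = [255, None]

if ignore_labels is not None and len(ignore_labels) != len(label_arrays):
    raise ValueError("ignore_labels must match number of label arrays when provided")

for vol_idx, (label_array, label_name) in enumerate(zip(label_arrays, label_names)):
    ignore_label = ignore_labels[vol_idx] if ignore_labels is not None else None
    print(label_name, "->", ignore_label)   # volume_a -> 255, volume_b -> None
```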
```diff
@@ -365,6 +407,12 @@ def _resolve_resolution(array_obj, level_key):
                     label_name,
                     actual_downsampled_patch_size,
                 )
+                if downsampled_array is None:
+                    logger.warning(
+                        "Volume '%s': unable to resolve label data for validation; skipping",
+                        label_name,
+                    )
+                    continue
             except Exception as e:
                 print(f"Error accessing resolution level {downsample_level} for {label_name}: {e}")
                 # Fallback to the array itself at full resolution
@@ -378,6 +426,12 @@ def _resolve_resolution(array_obj, level_key):
                     label_name,
                     e,
                 )
+            if downsampled_array is None:
+                logger.warning(
+                    "Volume '%s': no label data available after fallback; skipping",
+                    label_name,
+                )
+                continue

             # Check if data is 2D or 3D based on patch dimensionality
             spatial_ndim = len(actual_downsampled_patch_size)
```
```diff
@@ -458,7 +512,9 @@ def _resolve_resolution(array_obj, level_key):
                     x_start = x_group[0]
                     x_stop = x_group[-1] + dpX

-                    block = downsampled_array[y_start:y_stop, x_start:x_stop]
+                    block = np.asarray(downsampled_array[y_start:y_stop, x_start:x_stop])
+                    if ignore_label is not None:
+                        block = zero_ignore_labels(block, ignore_label)
                     block = reduce_block_to_scalar(
                         block,
                         spatial_ndim=2,
@@ -532,7 +588,11 @@ def _resolve_resolution(array_obj, level_key):
                     x_start = x_group[0]
                     x_stop = x_group[-1] + dpX

-                    block = downsampled_array[z_start:z_stop, y_start:y_stop, x_start:x_stop]
+                    block = np.asarray(
+                        downsampled_array[z_start:z_stop, y_start:y_stop, x_start:x_stop]
+                    )
+                    if ignore_label is not None:
+                        block = zero_ignore_labels(block, ignore_label)
                     block = reduce_block_to_scalar(
                         block,
                         spatial_ndim=3,
```
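One reading of the added `np.asarray(...)` wrappers (not stated explicitly in the PR): slicing a chunked or lazy label store may not return a plain NumPy array, while `zero_ignore_labels` copies and index-assigns the data, so materializing the block first keeps the masking purely in NumPy. Roughly:

```python
import numpy as np

lazy_slice = [[0, 255, 1], [255, 0, 1]]     # stand-in for a non-NumPy slice result
block = np.asarray(lazy_slice)              # materialize before masking
block = np.where(block == 255, 0, block)    # ignore value treated as background
print(block)
```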
The review comment accompanying the suggested change above:

> The expression `ignore_values or None` is redundant. An empty dict `{}` is falsy, so this will return `None` when `ignore_values` is empty. Consider simplifying to just `ignore_values` and letting the transform handle the empty dict, or use `ignore_values if ignore_values else None` to be more explicit about the intent.
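For quick reference, the truthiness behaviour the comment relies on:

```python
ignore_values = {}
print(ignore_values or None)                     # None (empty dict is falsy)
print(ignore_values if ignore_values else None)  # None (equivalent, more explicit)

ignore_values = {"sheet": 255}
print(ignore_values or None)                     # {'sheet': 255}
```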