Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion vesuvius/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ dependencies = [
"fsspec>=2025.9.0",
"fvcore>=0.1.5.post20221221",
"huggingface-hub>=0.35.0",
"jupyter>=1.1.1",
"libigl>=2.6.1",
"lxml>=6.0.2",
"magic-class>=0.7.17",
Expand All @@ -41,6 +42,7 @@ dependencies = [
"napari>=0.5.6",
"nest-asyncio>=1.6.0",
"nnunetv2",
"notebook>=7.4.7",
"numba>=0.60.0",
"numcodecs>=0.12.1",
"numpy>=2.0.2",
Expand All @@ -53,7 +55,7 @@ dependencies = [
"psutil>=7.1.0",
"pybind11>=3.0.1",
"pynrrd>=1.1.3",
"pyqt6",
"pyqt5>=5.15.2",
"pytest>=8.4.2",
"pytorch-lightning>=2.5.5",
"pytorch-optimizer>=3.8.0",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ def __init__(self,
do_tube: bool = True,
do_open: bool = False,
do_close: bool = True,
target_keys: Optional[Sequence[str]] = None,):
target_keys: Optional[Sequence[str]] = None,
ignore_values: Optional[dict] = None,):
"""
Calculates the medial surface skeleton of the segmentation (plus an optional 2 px tube around it)
and adds it to the dict with the key "skel"
Expand All @@ -22,6 +23,7 @@ def __init__(self,
self.do_open = do_open
self.do_close = do_close
self.target_keys = tuple(target_keys) if target_keys else None
self.ignore_values = dict(ignore_values or {})

def apply(self, data_dict, **params):
# Collect regression keys to avoid processing continuous aux targets
Expand All @@ -45,8 +47,14 @@ def apply(self, data_dict, **params):
orig_device = t.device
seg_all = t.detach().cpu().numpy()

bin_seg = seg_all > 0
seg_all_skel = np.zeros_like(bin_seg, dtype=np.float32)
ignore_value = self.ignore_values.get(target_key)
if ignore_value is not None:
seg_processed = np.where(seg_all == ignore_value, 0, seg_all)
else:
seg_processed = seg_all

bin_seg = seg_processed > 0
seg_all_skel = np.zeros_like(seg_processed, dtype=np.float32)

for c in range(bin_seg.shape[0]):
seg_c = bin_seg[c]
Expand All @@ -69,7 +77,7 @@ def apply(self, data_dict, **params):
if self.do_close:
skel = closing(skel)

seg_all_skel[c] = (skel.astype(np.float32) * seg_all[c].astype(np.float32))
seg_all_skel[c] = (skel.astype(np.float32) * seg_processed[c].astype(np.float32))

data_dict[f"{target_key}_skel"] = torch.from_numpy(seg_all_skel).to(orig_device)

Expand Down
63 changes: 57 additions & 6 deletions vesuvius/src/vesuvius/models/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ def _setup_chunk_slicer(self) -> None:
stride_override = tuple(int(v) for v in stride_override)

channel_selector = self._resolve_label_channel_selector(target_names)
ignore_map = self._resolve_target_ignore_labels(target_names)

config = ChunkSliceConfig(
patch_size=patch_size,
Expand All @@ -383,6 +384,7 @@ def _setup_chunk_slicer(self) -> None:
labels: Dict[str, Optional[np.ndarray]] = {}
label_source = None
cache_key_path: Optional[Path] = None
label_ignore_value: Optional[Union[int, float]] = None

for target_name in target_names:
try:
Expand All @@ -393,10 +395,16 @@ def _setup_chunk_slicer(self) -> None:
) from exc
label_array = self._get_entry_label(entry)
labels[target_name] = label_array
if label_source is None:
source_candidate = self._get_entry_label_source(entry)
if source_candidate is not None:
label_source = source_candidate
ignore_candidate = ignore_map.get(target_name)
source_candidate = self._get_entry_label_source(entry)
if source_candidate is None:
source_candidate = label_array
if label_source is None and source_candidate is not None:
label_source = source_candidate
if label_ignore_value is None and ignore_candidate is not None:
label_ignore_value = ignore_candidate
elif label_ignore_value is None and ignore_candidate is not None:
label_ignore_value = ignore_candidate
if cache_key_path is None:
path_candidate = self._get_entry_label_path(entry)
if path_candidate:
Expand All @@ -410,6 +418,7 @@ def _setup_chunk_slicer(self) -> None:
labels=labels,
label_source=label_source,
cache_key_path=cache_key_path,
label_ignore_value=label_ignore_value,
meshes=reference_info.get('meshes', {}),
)
)
Expand All @@ -426,6 +435,20 @@ def _resolve_label_channel_selector(
return self._normalize_channel_selector(selector)
return None

def _resolve_target_ignore_labels(
self, target_names: Sequence[str]
) -> Dict[str, Union[int, float]]:
ignore_map: Dict[str, Union[int, float]] = {}
for target_name in target_names:
info = self.targets.get(target_name) or {}
for alias in ("ignore_index", "ignore_label", "ignore_value"):
if alias in info:
value = info.get(alias)
if value is not None:
ignore_map[target_name] = value # store first non-null alias per target
break
return ignore_map

@staticmethod
def _normalize_channel_selector(
selector: object,
Expand Down Expand Up @@ -882,7 +905,21 @@ def _create_training_transforms(self):
skeleton_targets = self._skeleton_loss_targets()
if skeleton_targets:
from vesuvius.models.augmentation.transforms.utils.skeleton_transform import MedialSurfaceTransform
transforms.append(MedialSurfaceTransform(do_tube=False, target_keys=skeleton_targets))
ignore_values = {}
for target_name in skeleton_targets:
cfg = self.targets.get(target_name, {}) if isinstance(self.targets, dict) else {}
for alias in ("ignore_index", "ignore_label", "ignore_value"):
value = cfg.get(alias)
if value is not None:
ignore_values[target_name] = value
break
transforms.append(
MedialSurfaceTransform(
do_tube=False,
target_keys=skeleton_targets,
ignore_values=ignore_values or None,
Copy link

Copilot AI Nov 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The expression ignore_values or None is redundant. An empty dict {} is falsy, so this will return None when ignore_values is empty. Consider simplifying to just ignore_values and let the transform handle the empty dict, or use ignore_values if ignore_values else None to be more explicit about the intent.

Copilot uses AI. Check for mistakes.
)
)
print(f"Added MedialSurfaceTransform to training pipeline for targets: {', '.join(skeleton_targets)}")

return ComposeTransforms(transforms)
Expand All @@ -899,7 +936,21 @@ def _create_validation_transforms(self):
from vesuvius.models.augmentation.transforms.utils.skeleton_transform import MedialSurfaceTransform

transforms = []
transforms.append(MedialSurfaceTransform(do_tube=False, target_keys=skeleton_targets))
ignore_values = {}
for target_name in skeleton_targets:
cfg = self.targets.get(target_name, {}) if isinstance(self.targets, dict) else {}
for alias in ("ignore_index", "ignore_label", "ignore_value"):
value = cfg.get(alias)
if value is not None:
ignore_values[target_name] = value
break
transforms.append(
MedialSurfaceTransform(
do_tube=False,
target_keys=skeleton_targets,
ignore_values=ignore_values or None,
Copy link

Copilot AI Nov 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The expression ignore_values or None is redundant. An empty dict {} is falsy, so this will return None when ignore_values is empty. Consider simplifying to just ignore_values and let the transform handle the empty dict, or use ignore_values if ignore_values else None to be more explicit about the intent.

Suggested change
ignore_values=ignore_values or None,
ignore_values=ignore_values if ignore_values else None,

Copilot uses AI. Check for mistakes.
)
)
print(f"Added MedialSurfaceTransform to validation pipeline for targets: {', '.join(skeleton_targets)}")

return ComposeTransforms(transforms)
Expand Down
66 changes: 63 additions & 3 deletions vesuvius/src/vesuvius/models/datasets/find_valid_patches.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import time
from typing import Sequence, Union
from typing import Optional, Sequence, Union

import numpy as np
from tqdm import tqdm
Expand Down Expand Up @@ -146,6 +146,28 @@ def reduce_block_to_scalar(
flat = arr.reshape(spatial_shape + (int(np.prod(extra_shape)),))
return np.linalg.norm(flat, axis=-1)

def zero_ignore_labels(array: np.ndarray, ignore_label: Union[int, float]) -> np.ndarray:
"""Return a copy of ``array`` with the ignore label value zeroed out."""

arr = np.asarray(array)
if arr.size == 0:
return arr

if isinstance(ignore_label, float) and np.isnan(ignore_label):
mask = np.isnan(arr)
else:
try:
mask = (arr == ignore_label)
except TypeError:
Copy link

Copilot AI Nov 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This statement is unreachable.

Copilot uses AI. Check for mistakes.
return arr

if not np.any(mask):
return arr

result = arr.copy()
result[mask] = 0
return result


def check_patch_chunk(
chunk,
Expand All @@ -154,6 +176,7 @@ def check_patch_chunk(
bbox_threshold=0.5,
label_threshold=0.05,
channel_selector: Union[int, Sequence[int], None] = None,
ignore_label: Optional[Union[int, float]] = None,
):
"""Identify valid label patches within a chunk of candidate positions."""

Expand Down Expand Up @@ -182,6 +205,8 @@ def check_patch_chunk(
patch = sheet_label[base_slice + direct_selector]
else:
patch = sheet_label[base_slice]
if ignore_label is not None:
patch = zero_ignore_labels(patch, ignore_label)
patch = collapse_patch_to_spatial(
patch,
spatial_ndim=2,
Expand Down Expand Up @@ -210,6 +235,8 @@ def check_patch_chunk(
patch = sheet_label[base_slice + direct_selector]
else:
patch = sheet_label[base_slice]
if ignore_label is not None:
patch = zero_ignore_labels(patch, ignore_label)
patch = collapse_patch_to_spatial(
patch,
spatial_ndim=3,
Expand Down Expand Up @@ -248,6 +275,7 @@ def find_valid_patches(
num_workers=4,
downsample_level=1,
channel_selectors: Sequence[Union[int, Sequence[int], None]] | None = None,
ignore_labels: Sequence[Optional[Union[int, float]]] | None = None,
):
"""
Finds patches that contain:
Expand All @@ -264,12 +292,15 @@ def find_valid_patches(
max_z, max_y, max_x: maximum coordinates for patch extraction (full resolution)
num_workers: number of processes for parallel processing
downsample_level: Resolution level to use for patch finding (0=full res, 1=2x downsample, etc.)
ignore_labels: Optional per-volume ignore values that should be treated as background.

Returns:
List of dictionaries with 'volume_idx', 'volume_name', and 'start_pos' (coordinates at full resolution)
"""
if len(label_arrays) != len(label_names):
raise ValueError("Number of label arrays must match number of label names")
if ignore_labels is not None and len(ignore_labels) != len(label_arrays):
raise ValueError("ignore_labels must match number of label arrays when provided")

all_valid_patches = []

Expand Down Expand Up @@ -304,6 +335,17 @@ def find_valid_patches(
selector = None
if channel_selectors is not None:
selector = channel_selectors[vol_idx]
ignore_label = None
if ignore_labels is not None:
ignore_label = ignore_labels[vol_idx]

if label_array is None:
logger.warning(
"Volume '%s' has no label array available at index %d; skipping validation",
label_name,
vol_idx,
)
continue

# Access the appropriate resolution level for patch finding
actual_downsample_factor = downsample_factor
Expand Down Expand Up @@ -365,6 +407,12 @@ def _resolve_resolution(array_obj, level_key):
label_name,
actual_downsampled_patch_size,
)
if downsampled_array is None:
logger.warning(
"Volume '%s': unable to resolve label data for validation; skipping",
label_name,
)
continue
except Exception as e:
print(f"Error accessing resolution level {downsample_level} for {label_name}: {e}")
# Fallback to the array itself at full resolution
Expand All @@ -378,6 +426,12 @@ def _resolve_resolution(array_obj, level_key):
label_name,
e,
)
if downsampled_array is None:
logger.warning(
"Volume '%s': no label data available after fallback; skipping",
label_name,
)
continue

# Check if data is 2D or 3D based on patch dimensionality
spatial_ndim = len(actual_downsampled_patch_size)
Expand Down Expand Up @@ -458,7 +512,9 @@ def _resolve_resolution(array_obj, level_key):
x_start = x_group[0]
x_stop = x_group[-1] + dpX

block = downsampled_array[y_start:y_stop, x_start:x_stop]
block = np.asarray(downsampled_array[y_start:y_stop, x_start:x_stop])
if ignore_label is not None:
block = zero_ignore_labels(block, ignore_label)
block = reduce_block_to_scalar(
block,
spatial_ndim=2,
Expand Down Expand Up @@ -532,7 +588,11 @@ def _resolve_resolution(array_obj, level_key):
x_start = x_group[0]
x_stop = x_group[-1] + dpX

block = downsampled_array[z_start:z_stop, y_start:y_stop, x_start:x_stop]
block = np.asarray(
downsampled_array[z_start:z_stop, y_start:y_stop, x_start:x_stop]
)
if ignore_label is not None:
block = zero_ignore_labels(block, ignore_label)
block = reduce_block_to_scalar(
block,
spatial_ndim=3,
Expand Down
Loading
Loading