Skip to content

Commit ac9273f

Browse files
authored
Cucim-based extend mask (#22)
* Cucim-based extend mask * Fix SOM problem
1 parent 6c27126 commit ac9273f

File tree

8 files changed

+1056
-518
lines changed

8 files changed

+1056
-518
lines changed

docs/source/tutorials/segment.ipynb

Lines changed: 959 additions & 363 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ packages = ["spatiomic"]
77

88
[project]
99
name = "spatiomic"
10-
version = "0.9.1"
10+
version = "0.9.2"
1111
description = "A python toolbox for spatial omics analysis."
1212
requires-python = ">=3.11"
1313
license = { file = "LICENSE" }

spatiomic/dimension/_som.py

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def __init__(
9090

9191
def set_estimators(
9292
self,
93+
initialize_som: bool = True,
9394
) -> None:
9495
"""Set the XPySOM and nearest neighbor finder estimators."""
9596
# Check whether we can use cupy to work on the GPU
@@ -98,29 +99,30 @@ def set_estimators(
9899
)
99100

100101
# Initialise XPySOM instance
101-
try:
102-
self.som = XPySom(
103-
self.node_count[0],
104-
self.node_count[1],
105-
self.dimension_count,
106-
activation_distance=self.distance_metric,
107-
neighborhood_function=self.neighborhood,
108-
learning_rate=self.learning_rate_initial,
109-
learning_rateN=self.learning_rate_final,
110-
sigma=self.sigma_initial,
111-
sigmaN=self.sigma_final,
112-
xp=self.xp,
113-
n_parallel=self.parallel_count,
114-
random_seed=self.seed,
115-
)
116-
except ValueError as excp:
117-
if self.distance_metric == "correlation":
118-
raise ValueError(
119-
"Using XPySOM with Pearson correlation requires a custom implementation. "
120-
"You can install it via `pip install git+https://github.com/complextissue/xpysom`."
121-
) from excp
122-
else:
123-
raise excp
102+
if initialize_som:
103+
try:
104+
self.som = XPySom(
105+
self.node_count[0],
106+
self.node_count[1],
107+
self.dimension_count,
108+
activation_distance=self.distance_metric,
109+
neighborhood_function=self.neighborhood,
110+
learning_rate=self.learning_rate_initial,
111+
learning_rateN=self.learning_rate_final,
112+
sigma=self.sigma_initial,
113+
sigmaN=self.sigma_final,
114+
xp=self.xp,
115+
n_parallel=self.parallel_count,
116+
random_seed=self.seed,
117+
)
118+
except ValueError as excp:
119+
if self.distance_metric == "correlation":
120+
raise ValueError(
121+
"Using XPySOM with Pearson correlation requires a custom implementation. "
122+
"You can install it via `pip install git+https://github.com/complextissue/xpysom`."
123+
) from excp
124+
else:
125+
raise excp
124126

125127
# Create the nearest neighbor finder
126128
self.neighbor_estimator = get_neighbor_finder(
@@ -366,13 +368,16 @@ def load(
366368
Args:
367369
save_path (str): The path where to load the SOM and its configuration from.
368370
"""
369-
# load and set the som and load the class config
371+
# Load and set the som and load the class config
370372
with open(save_path, "rb") as infile:
371373
config, self.som = pickle.load(infile) # nosec
372374

373-
# initialise an XPySOM object with the data and set the class variables
375+
# Initialise an XPySOM object with the data and set the class variables
374376
self.set_config(**config)
375377

378+
# Re-initialize estimators without re-initializing the SOM
379+
self.set_estimators(initialize_som=False)
380+
376381
def save(
377382
self,
378383
save_path: str,

spatiomic/process/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,14 @@
1212
from ._register import Register as register
1313
from ._zscore import ZScore as zscore
1414

15+
standardize = zscore
16+
1517
__all__ = [
1618
"arcsinh",
1719
"clip",
1820
"log1p",
1921
"normalize",
2022
"register",
23+
"standardize",
2124
"zscore",
2225
]

spatiomic/segment/_extend_mask.py

Lines changed: 44 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Mask extension implementation for spatiomic."""
22

3-
from typing import TYPE_CHECKING, Any
3+
from typing import TYPE_CHECKING, Any, cast
44

55
import numpy as np
66
import numpy.typing as npt
@@ -17,53 +17,21 @@ def extend_mask(
1717
) -> npt.NDArray[np.integer[Any]]:
1818
"""Extend segmentation masks by dilating them up to halfway to the nearest neighboring mask.
1919
20-
This function dilates each mask region by a specified number of pixels, but stops dilation
21-
at the halfway point to the nearest neighboring mask to ensure fair distribution of space
22-
between adjacent regions.
20+
Uses Voronoi tessellation via distance transform for O(1) complexity regardless of label count.
21+
Each background pixel is assigned to its nearest mask, constrained by dilation_pixels.
2322
2423
Args:
25-
masks: Input segmentation masks where each unique integer represents a different segment.
26-
Background pixels should have a consistent label (default 0).
27-
dilation_pixels: Maximum number of pixels to dilate each mask. The actual dilation
28-
may be less if masks meet halfway. Must be positive.
24+
masks: 2D input segmentation masks where each unique integer represents a different segment.
25+
dilation_pixels: Maximum number of pixels to dilate each mask. Must be positive.
2926
background_label: The label value representing background pixels. Defaults to 0.
30-
use_gpu: Whether to use GPU acceleration with CuPy. Defaults to False.
27+
use_gpu: Whether to use GPU acceleration with CuPy/cuCIM. Defaults to False.
3128
3229
Returns:
33-
Extended masks with the same shape and dtype as input, where each mask has been
34-
dilated up to the halfway point to neighboring masks.
30+
Extended masks with the same shape and dtype as input.
3531
3632
Raises:
3733
ValueError: If dilation_pixels is not positive or if input is not 2D.
38-
39-
Example:
40-
Basic usage for extending cell masks:
41-
42-
```python
43-
import spatiomic as so
44-
import numpy as np
45-
46-
# Create example segmentation masks
47-
masks = np.array([[0, 0, 0, 0, 0], [0, 1, 0, 2, 0], [0, 0, 0, 0, 0], [0, 3, 0, 4, 0], [0, 0, 0, 0, 0]])
48-
49-
# Extend masks by 1 pixel
50-
extended_masks = so.segment.extend_mask(masks, dilation_pixels=1)
51-
52-
# Each mask will expand towards neighboring masks but stop halfway
53-
```
54-
55-
Note:
56-
- The algorithm uses distance transforms for efficient computation
57-
- Memory usage scales with image size and number of unique labels
58-
- GPU acceleration is available when CuPy is installed
5934
"""
60-
if TYPE_CHECKING or not use_gpu:
61-
xp = np
62-
ndimage_pkg = ndimage
63-
else:
64-
xp = import_package("cupy", alternative=np)
65-
ndimage_pkg = import_package("cupyx.scipy.ndimage", alternative=ndimage)
66-
6735
if dilation_pixels <= 0:
6836
msg = f"dilation_pixels must be positive, got {dilation_pixels}"
6937
raise ValueError(msg)
@@ -72,79 +40,45 @@ def extend_mask(
7240
msg = f"Input masks must be 2D, got {masks.ndim}D"
7341
raise ValueError(msg)
7442

75-
# Convert to GPU array if using CuPy
76-
masks_xp = xp.asarray(masks)
43+
original_dtype = masks.dtype
7744

78-
# Get unique labels excluding background
79-
unique_labels = xp.unique(masks_xp)
80-
unique_labels = unique_labels[unique_labels != background_label]
81-
82-
if len(unique_labels) == 0:
83-
return masks # No masks to extend
45+
if TYPE_CHECKING or not use_gpu:
46+
xp = np
47+
cucim_morphology = None
48+
else:
49+
xp = import_package("cupy", alternative=np)
50+
cucim_morphology = import_package("cucim.core.operations.morphology", alternative=None)
51+
if cucim_morphology is None:
52+
use_gpu = False
53+
xp = np
8454

85-
# Create arrays for efficient vectorized computation
86-
extended_masks = masks_xp.copy()
55+
masks_xp = xp.asarray(masks)
8756
background_mask = masks_xp == background_label
8857

8958
if not xp.any(background_mask):
90-
# No background pixels to extend into
91-
if xp.__name__ == "cupy" and hasattr(extended_masks, "get"):
92-
return extended_masks.get() # type: ignore[no-any-return]
93-
return np.asarray(extended_masks) # type: ignore[no-any-return]
94-
95-
# For efficient computation, we'll use distance transforms
96-
# Create distance maps for each label
97-
label_distances = xp.full((len(unique_labels), *masks_xp.shape), xp.inf, dtype=xp.float32)
98-
99-
for idx, label in enumerate(unique_labels):
100-
# Create binary mask for this label
101-
label_mask = masks_xp == label
102-
103-
if xp.__name__ == "cupy":
104-
# Convert to numpy for scipy operations
105-
label_mask_np = label_mask.get() if hasattr(label_mask, "get") else np.asarray(label_mask)
106-
dist_transform = ndimage.distance_transform_edt(~label_mask_np)
107-
label_distances[idx] = xp.asarray(dist_transform)
108-
else:
109-
dist_transform = ndimage_pkg.distance_transform_edt(~label_mask)
110-
label_distances[idx] = dist_transform
111-
112-
# For each background pixel, find closest two labels and their distances
113-
background_indices = xp.where(background_mask)
114-
115-
if len(background_indices[0]) > 0:
116-
# Vectorized computation for all background pixels
117-
distances_at_bg = label_distances[:, background_indices[0], background_indices[1]]
118-
119-
# Find two closest labels for each background pixel
120-
sorted_indices = xp.argsort(distances_at_bg, axis=0)
121-
closest_distances = xp.take_along_axis(distances_at_bg, sorted_indices, axis=0)
122-
closest_labels = unique_labels[sorted_indices]
123-
124-
# Calculate conditions for assignment
125-
closest_dist = closest_distances[0]
126-
within_dilation = closest_dist <= dilation_pixels
127-
128-
# For pixels with multiple nearby labels, check halfway condition
129-
has_second_neighbor = len(unique_labels) > 1
130-
if has_second_neighbor:
131-
second_closest_dist = closest_distances[1] if len(unique_labels) > 1 else xp.inf
132-
halfway_dist = (closest_dist + second_closest_dist) / 2.0
133-
within_halfway = closest_dist <= halfway_dist
134-
assignment_condition = within_dilation & within_halfway
135-
else:
136-
assignment_condition = within_dilation
137-
138-
# Assign pixels to closest labels where conditions are met
139-
valid_assignments = xp.where(assignment_condition)[0]
140-
if len(valid_assignments) > 0:
141-
assigned_rows = background_indices[0][valid_assignments]
142-
assigned_cols = background_indices[1][valid_assignments]
143-
assigned_labels = closest_labels[0][valid_assignments]
144-
extended_masks[assigned_rows, assigned_cols] = assigned_labels
145-
146-
# Convert back to numpy if we used CuPy
147-
if xp.__name__ == "cupy" and hasattr(extended_masks, "get"):
148-
return extended_masks.get() # type: ignore[no-any-return]
149-
150-
return np.asarray(extended_masks) # type: ignore[no-any-return]
59+
if use_gpu:
60+
return cast(npt.NDArray[np.integer[Any]], masks_xp.get().astype(original_dtype)) # type: ignore[attr-defined]
61+
return masks.copy()
62+
63+
foreground_mask = ~background_mask
64+
65+
if not xp.any(foreground_mask):
66+
if use_gpu:
67+
return cast(npt.NDArray[np.integer[Any]], masks_xp.get().astype(original_dtype)) # type: ignore[attr-defined]
68+
return masks.copy()
69+
70+
if use_gpu and cucim_morphology is not None:
71+
distances, indices = cucim_morphology.distance_transform_edt(background_mask, return_indices=True)
72+
extended = masks_xp[indices[0], indices[1]]
73+
extended = xp.where(distances <= dilation_pixels, extended, background_label) # type: ignore[union-attr]
74+
extended = xp.where(foreground_mask, masks_xp, extended) # type: ignore[union-attr]
75+
del distances, indices, background_mask, foreground_mask, masks_xp
76+
xp.get_default_memory_pool().free_all_blocks() # type: ignore[union-attr]
77+
return cast(npt.NDArray[np.integer[Any]], extended.get().astype(original_dtype))
78+
79+
distances, indices = ndimage.distance_transform_edt(background_mask, return_indices=True)
80+
extended = masks[indices[0], indices[1]]
81+
extended = np.where(distances <= dilation_pixels, extended, background_label)
82+
extended = np.where(foreground_mask, masks, extended)
83+
84+
return cast(npt.NDArray[np.integer[Any]], extended.astype(original_dtype))
6.84 KB
Binary file not shown.

test/dimension/test_som.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -79,27 +79,27 @@ def test_som_cpu(example_data: NDArray) -> None:
7979

8080
# TODO: add test cases for flattening and returning distances
8181

82-
# test quantization error calculation
83-
quantization_error = data_som.get_quantization_error(example_data)
84-
assert isinstance(quantization_error, float)
82+
# test quantization error calculation
83+
quantization_error = data_som.get_quantization_error(example_data)
84+
assert isinstance(quantization_error, float)
8585

86-
if distance_metric in ["correlation", "cosine"]:
87-
assert quantization_error >= 0.0 and quantization_error <= 1.0, (
88-
f"Quantization error out of bounds for {distance_metric} distance metric: {quantization_error}"
89-
)
86+
if distance_metric in ["correlation", "cosine"]:
87+
assert quantization_error >= 0.0 and quantization_error <= 1.0, (
88+
f"Quantization error out of bounds for {distance_metric} distance metric: {quantization_error}"
89+
)
9090

91-
# test saving and loading
92-
temp_file_name = f"{uuid4()}.p"
93-
temp_file_name = os.path.join(os.path.dirname(os.path.realpath(__file__)), temp_file_name)
94-
data_som.save(save_path=temp_file_name)
91+
# test saving and loading
92+
temp_file_name = f"{uuid4()}.p"
93+
temp_file_name = os.path.join(os.path.dirname(os.path.realpath(__file__)), temp_file_name)
94+
data_som.save(save_path=temp_file_name)
9595

96-
assert os.path.isfile(temp_file_name)
96+
assert os.path.isfile(temp_file_name)
9797

98-
new_som = so.dimension.som()
99-
new_som.load(save_path=temp_file_name)
98+
new_som = so.dimension.som()
99+
new_som.load(save_path=temp_file_name)
100100

101-
assert data_som.get_config() == new_som.get_config()
102-
assert np.all(data_som.get_nodes() == new_som.get_nodes())
101+
assert data_som.get_config() == new_som.get_config()
102+
assert np.all(data_som.get_nodes() == new_som.get_nodes())
103103

104-
# remove the temp file
105-
os.remove(temp_file_name)
104+
# remove the temp file
105+
os.remove(temp_file_name)

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)