11"""Mask extension implementation for spatiomic."""
22
3- from typing import TYPE_CHECKING , Any
3+ from typing import TYPE_CHECKING , Any , cast
44
55import numpy as np
66import numpy .typing as npt
@@ -17,53 +17,21 @@ def extend_mask(
1717) -> npt .NDArray [np .integer [Any ]]:
1818 """Extend segmentation masks by dilating them up to halfway to the nearest neighboring mask.
1919
20- This function dilates each mask region by a specified number of pixels, but stops dilation
21- at the halfway point to the nearest neighboring mask to ensure fair distribution of space
22- between adjacent regions.
20+ Uses Voronoi tessellation via distance transform for O(1) complexity regardless of label count.
21+ Each background pixel is assigned to its nearest mask, constrained by dilation_pixels.
2322
2423 Args:
25- masks: Input segmentation masks where each unique integer represents a different segment.
26- Background pixels should have a consistent label (default 0).
27- dilation_pixels: Maximum number of pixels to dilate each mask. The actual dilation
28- may be less if masks meet halfway. Must be positive.
24+ masks: 2D input segmentation masks where each unique integer represents a different segment.
25+ dilation_pixels: Maximum number of pixels to dilate each mask. Must be positive.
2926 background_label: The label value representing background pixels. Defaults to 0.
30- use_gpu: Whether to use GPU acceleration with CuPy. Defaults to False.
27+ use_gpu: Whether to use GPU acceleration with CuPy/cuCIM . Defaults to False.
3128
3229 Returns:
33- Extended masks with the same shape and dtype as input, where each mask has been
34- dilated up to the halfway point to neighboring masks.
30+ Extended masks with the same shape and dtype as input.
3531
3632 Raises:
3733 ValueError: If dilation_pixels is not positive or if input is not 2D.
38-
39- Example:
40- Basic usage for extending cell masks:
41-
42- ```python
43- import spatiomic as so
44- import numpy as np
45-
46- # Create example segmentation masks
47- masks = np.array([[0, 0, 0, 0, 0], [0, 1, 0, 2, 0], [0, 0, 0, 0, 0], [0, 3, 0, 4, 0], [0, 0, 0, 0, 0]])
48-
49- # Extend masks by 1 pixel
50- extended_masks = so.segment.extend_mask(masks, dilation_pixels=1)
51-
52- # Each mask will expand towards neighboring masks but stop halfway
53- ```
54-
55- Note:
56- - The algorithm uses distance transforms for efficient computation
57- - Memory usage scales with image size and number of unique labels
58- - GPU acceleration is available when CuPy is installed
5934 """
60- if TYPE_CHECKING or not use_gpu :
61- xp = np
62- ndimage_pkg = ndimage
63- else :
64- xp = import_package ("cupy" , alternative = np )
65- ndimage_pkg = import_package ("cupyx.scipy.ndimage" , alternative = ndimage )
66-
6735 if dilation_pixels <= 0 :
6836 msg = f"dilation_pixels must be positive, got { dilation_pixels } "
6937 raise ValueError (msg )
@@ -72,79 +40,45 @@ def extend_mask(
7240 msg = f"Input masks must be 2D, got { masks .ndim } D"
7341 raise ValueError (msg )
7442
75- # Convert to GPU array if using CuPy
76- masks_xp = xp .asarray (masks )
43+ original_dtype = masks .dtype
7744
78- # Get unique labels excluding background
79- unique_labels = xp .unique (masks_xp )
80- unique_labels = unique_labels [unique_labels != background_label ]
81-
82- if len (unique_labels ) == 0 :
83- return masks # No masks to extend
45+ if TYPE_CHECKING or not use_gpu :
46+ xp = np
47+ cucim_morphology = None
48+ else :
49+ xp = import_package ("cupy" , alternative = np )
50+ cucim_morphology = import_package ("cucim.core.operations.morphology" , alternative = None )
51+ if cucim_morphology is None :
52+ use_gpu = False
53+ xp = np
8454
85- # Create arrays for efficient vectorized computation
86- extended_masks = masks_xp .copy ()
55+ masks_xp = xp .asarray (masks )
8756 background_mask = masks_xp == background_label
8857
8958 if not xp .any (background_mask ):
90- # No background pixels to extend into
91- if xp .__name__ == "cupy" and hasattr (extended_masks , "get" ):
92- return extended_masks .get () # type: ignore[no-any-return]
93- return np .asarray (extended_masks ) # type: ignore[no-any-return]
94-
95- # For efficient computation, we'll use distance transforms
96- # Create distance maps for each label
97- label_distances = xp .full ((len (unique_labels ), * masks_xp .shape ), xp .inf , dtype = xp .float32 )
98-
99- for idx , label in enumerate (unique_labels ):
100- # Create binary mask for this label
101- label_mask = masks_xp == label
102-
103- if xp .__name__ == "cupy" :
104- # Convert to numpy for scipy operations
105- label_mask_np = label_mask .get () if hasattr (label_mask , "get" ) else np .asarray (label_mask )
106- dist_transform = ndimage .distance_transform_edt (~ label_mask_np )
107- label_distances [idx ] = xp .asarray (dist_transform )
108- else :
109- dist_transform = ndimage_pkg .distance_transform_edt (~ label_mask )
110- label_distances [idx ] = dist_transform
111-
112- # For each background pixel, find closest two labels and their distances
113- background_indices = xp .where (background_mask )
114-
115- if len (background_indices [0 ]) > 0 :
116- # Vectorized computation for all background pixels
117- distances_at_bg = label_distances [:, background_indices [0 ], background_indices [1 ]]
118-
119- # Find two closest labels for each background pixel
120- sorted_indices = xp .argsort (distances_at_bg , axis = 0 )
121- closest_distances = xp .take_along_axis (distances_at_bg , sorted_indices , axis = 0 )
122- closest_labels = unique_labels [sorted_indices ]
123-
124- # Calculate conditions for assignment
125- closest_dist = closest_distances [0 ]
126- within_dilation = closest_dist <= dilation_pixels
127-
128- # For pixels with multiple nearby labels, check halfway condition
129- has_second_neighbor = len (unique_labels ) > 1
130- if has_second_neighbor :
131- second_closest_dist = closest_distances [1 ] if len (unique_labels ) > 1 else xp .inf
132- halfway_dist = (closest_dist + second_closest_dist ) / 2.0
133- within_halfway = closest_dist <= halfway_dist
134- assignment_condition = within_dilation & within_halfway
135- else :
136- assignment_condition = within_dilation
137-
138- # Assign pixels to closest labels where conditions are met
139- valid_assignments = xp .where (assignment_condition )[0 ]
140- if len (valid_assignments ) > 0 :
141- assigned_rows = background_indices [0 ][valid_assignments ]
142- assigned_cols = background_indices [1 ][valid_assignments ]
143- assigned_labels = closest_labels [0 ][valid_assignments ]
144- extended_masks [assigned_rows , assigned_cols ] = assigned_labels
145-
146- # Convert back to numpy if we used CuPy
147- if xp .__name__ == "cupy" and hasattr (extended_masks , "get" ):
148- return extended_masks .get () # type: ignore[no-any-return]
149-
150- return np .asarray (extended_masks ) # type: ignore[no-any-return]
59+ if use_gpu :
60+ return cast (npt .NDArray [np .integer [Any ]], masks_xp .get ().astype (original_dtype )) # type: ignore[attr-defined]
61+ return masks .copy ()
62+
63+ foreground_mask = ~ background_mask
64+
65+ if not xp .any (foreground_mask ):
66+ if use_gpu :
67+ return cast (npt .NDArray [np .integer [Any ]], masks_xp .get ().astype (original_dtype )) # type: ignore[attr-defined]
68+ return masks .copy ()
69+
70+ if use_gpu and cucim_morphology is not None :
71+ distances , indices = cucim_morphology .distance_transform_edt (background_mask , return_indices = True )
72+ extended = masks_xp [indices [0 ], indices [1 ]]
73+ extended = xp .where (distances <= dilation_pixels , extended , background_label ) # type: ignore[union-attr]
74+ extended = xp .where (foreground_mask , masks_xp , extended ) # type: ignore[union-attr]
75+ del distances , indices , background_mask , foreground_mask , masks_xp
76+ xp .get_default_memory_pool ().free_all_blocks () # type: ignore[union-attr]
77+ return cast (npt .NDArray [np .integer [Any ]], extended .get ().astype (original_dtype ))
78+
79+ distances , indices = ndimage .distance_transform_edt (background_mask , return_indices = True )
80+ extended = masks [indices [0 ], indices [1 ]]
81+ extended = np .where (distances <= dilation_pixels , extended , background_label )
82+ extended = np .where (foreground_mask , masks , extended )
83+
84+ return cast (npt .NDArray [np .integer [Any ]], extended .astype (original_dtype ))
0 commit comments