Skip to content

Commit c6aecce

Browse files
author
Donglai Wei
committed
add binary segmentation decoding and postprocessing
1 parent 3bb37a6 commit c6aecce

File tree

11 files changed

+339
-147
lines changed

11 files changed

+339
-147
lines changed

connectomics/config/hydra_config.py

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -628,8 +628,8 @@ class AffineConfig:
628628
enabled: Optional[bool] = None
629629
prob: float = 0.5
630630
rotate_range: Tuple[float, float, float] = (0.2, 0.2, 0.2) # Rotation range in radians (~11°)
631-
scale_range: Tuple[float, float, float] = (0.1, 0.1, 0.1) # Scaling range (±10%)
632-
shear_range: Tuple[float, float, float] = (0.1, 0.1, 0.1) # Shearing range (±10%)
631+
scale_range: Tuple[float, float, float] = (0.1, 0.1, 0.1) # Scaling range (±10%)
632+
shear_range: Tuple[float, float, float] = (0.1, 0.1, 0.1) # Shearing range (±10%)
633633

634634

635635
@dataclass
@@ -913,16 +913,64 @@ class DecodeModeConfig:
913913
) # Keyword arguments for the decode function
914914

915915

916+
@dataclass
class BinaryPostprocessingConfig:
    """Binary segmentation postprocessing configuration.

    Describes the refinement steps applied to a binary foreground prediction.
    Pipeline order:
        1. Threshold predictions to a binary mask (``threshold_range[0]``)
        2. Apply a median filter (optional, ``median_filter_size``)
        3. Apply morphological opening (erosion + dilation)
        4. Apply morphological closing (dilation + erosion)
        5. Extract connected components and filter them (size / top-k)
    """

    # Master switch for the binary postprocessing pipeline.
    enabled: bool = False
    # Median filter kernel size (e.g., (3, 3) for 2D). None = no median filter.
    median_filter_size: Optional[Tuple[int, ...]] = None
    # Number of morphological opening iterations (0 = skip opening).
    opening_iterations: int = 0
    # Number of morphological closing iterations (0 = skip closing).
    closing_iterations: int = 0
    # Connected-components filtering config. None = no CC filtering.
    connected_components: Optional["ConnectedComponentsConfig"] = None
    # (min_threshold, max_threshold) used to binarize predictions; only the
    # minimum is used as the cut-off. Declared last so that existing positional
    # construction of the fields above keeps working.
    threshold_range: Tuple[float, float] = (0.5, 1.0)
935+
936+
937+
@dataclass
class ConnectedComponentsConfig:
    """Settings for filtering a binary mask by its connected components."""

    # Turn connected-components filtering on or off.
    enabled: bool = True

    # Number of largest components to retain; None keeps every component.
    top_k: Optional[int] = None

    # Smallest component size, in voxels, that is allowed to survive.
    min_size: int = 0

    # Neighbourhood used when labelling components:
    # 1 = face, 2 = face + edge, 3 = face + edge + corner.
    connectivity: int = 1
945+
946+
916947
@dataclass
917948
class PostprocessingConfig:
918-
"""Postprocessing configuration."""
949+
"""Postprocessing configuration for inference output.
950+
951+
Controls how predictions are transformed before saving:
952+
- Thresholding: Binarize predictions using threshold_range
953+
- Scaling: Multiply intensity values (e.g., 255 for uint8)
954+
- Dtype conversion: Convert to target data type with proper clamping
955+
- Transpose: Reorder axes (e.g., [2,1,0] for zyx->xyz)
956+
"""
919957

920-
output_scale: Optional[float] = (
958+
# Thresholding configuration
959+
binary: Optional[Dict[str, Any]] = field(
960+
default_factory=dict
961+
) # Binary thresholding config (e.g., {'threshold_range': [0.5, 1.0]})
962+
963+
# Intensity scaling
964+
intensity_scale: Optional[float] = (
921965
None # Scale predictions for saving (e.g., 255.0 for uint8). None = no scaling
922966
)
923-
output_dtype: Optional[str] = (
967+
968+
# Data type conversion
969+
intensity_dtype: Optional[str] = (
924970
None # Output data type: 'uint8', 'uint16', 'float32'. None = no conversion (keep as-is)
925971
)
972+
973+
# Axis permutation
926974
output_transpose: List[int] = field(
927975
default_factory=list
928976
) # Axis permutation for output (e.g., [2,1,0] for zyx->xyz)
@@ -1072,6 +1120,8 @@ def configure_instance_segmentation(cfg: Config, boundary_thickness: int = 5) ->
10721120
"SlidingWindowConfig",
10731121
"TestTimeAugmentationConfig",
10741122
"PostprocessingConfig",
1123+
"BinaryPostprocessingConfig",
1124+
"ConnectedComponentsConfig",
10751125
"EvaluationConfig",
10761126
"DecodeModeConfig",
10771127
"DecodeBinaryContourDistanceWatershedConfig",

connectomics/data/process/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
# Core processing functions
22
from .bbox_processor import * # New: unified bbox processing framework
33

4+
# Utility functions used by decoding
5+
from .misc import get_seg_type
6+
from .bbox import bbox_ND, crop_ND, replace_ND
7+
48
# MONAI-native transforms and composition
59

610
# Pipeline builder (primary entry point for label transforms)

connectomics/decoding/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
"""
2020

2121
from .segmentation import (
22+
decode_binary_thresholding,
2223
decode_binary_cc,
2324
decode_binary_watershed,
2425
decode_binary_contour_cc,
@@ -58,6 +59,7 @@
5859

5960
__all__ = [
6061
# Segmentation decoding
62+
'decode_binary_thresholding',
6163
'decode_binary_cc',
6264
'decode_binary_watershed',
6365
'decode_binary_contour_cc',

connectomics/decoding/postprocess.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
"watershed_split",
3333
"stitch_3d",
3434
"intersection_over_union",
35+
"apply_binary_postprocessing",
3536
]
3637

3738

@@ -248,3 +249,108 @@ def _label_overlap(x: np.ndarray, y: np.ndarray) -> np.ndarray:
248249
for i in range(len(x)):
249250
overlap[x[i], y[i]] += 1
250251
return overlap
252+
253+
254+
def apply_binary_postprocessing(
    pred: np.ndarray, config: 'BinaryPostprocessingConfig'
) -> np.ndarray:
    """Apply the binary segmentation postprocessing pipeline.

    Pipeline order:
        1. Threshold predictions to a binary mask
        2. Apply median filter (optional)
        3. Apply morphological opening (erosion + dilation)
        4. Apply morphological closing (dilation + erosion)
        5. Extract connected components and filter by size / keep top-k

    Args:
        pred (numpy.ndarray): Predicted foreground probability in range [0, 1].
            Shape can be 2D (H, W) or 3D (D, H, W).
        config (BinaryPostprocessingConfig): Configuration for the pipeline.
            If it is falsy or ``config.enabled`` is false, the prediction is
            only thresholded at 0.5. If the config carries a
            ``threshold_range`` attribute, its first element is the threshold;
            otherwise 0.5 is used, which leaves an already-binarized input
            unchanged.

    Returns:
        numpy.ndarray: Postprocessed binary mask (same shape as input),
        dtype ``uint8``. Values: 0 (background) or 1 (foreground).

    Raises:
        ValueError: If connected-components filtering is enabled and the
            configured connectivity cannot be used for the mask's rank.

    Example:
        >>> from types import SimpleNamespace
        >>> config = SimpleNamespace(
        ...     enabled=True,
        ...     threshold_range=(0.8, 1.0),
        ...     median_filter_size=None,
        ...     opening_iterations=2,
        ...     closing_iterations=0,
        ...     connected_components=None,
        ... )
        >>> pred = np.random.rand(128, 128)  # Random probabilities
        >>> binary_mask = apply_binary_postprocessing(pred, config)
    """
    if not config or not config.enabled:
        # Just threshold at 0.5 if postprocessing is disabled
        return (pred > 0.5).astype(np.uint8)

    # Step 1: Threshold to binary. The config may not define threshold_range,
    # so fall back to the same 0.5 cut-off used by the disabled path instead
    # of failing with AttributeError.
    threshold_range = getattr(config, "threshold_range", None)
    threshold = 0.5 if threshold_range is None else threshold_range[0]
    binary = (pred > threshold).astype(np.uint8)

    # Step 2: Median filter (optional noise reduction)
    if config.median_filter_size is not None:
        binary = ndimage.median_filter(binary, size=config.median_filter_size)

    # Step 3: Morphological opening (erosion + dilation) - removes small objects
    if config.opening_iterations > 0:
        binary = ndimage.binary_opening(
            binary, iterations=config.opening_iterations
        ).astype(np.uint8)

    # Step 4: Morphological closing (dilation + erosion) - fills small holes
    if config.closing_iterations > 0:
        binary = ndimage.binary_closing(
            binary, iterations=config.closing_iterations
        ).astype(np.uint8)

    # Step 5: Connected components filtering
    cc_config = config.connected_components
    if cc_config is not None and cc_config.enabled:
        binary = _filter_connected_components(binary, cc_config)

    return binary


def _resolve_cc3d_connectivity(connectivity: int, ndim: int) -> int:
    """Translate a 1/2/3 neighbourhood rank into the neighbour count cc3d expects."""
    rank_to_neighbours = {2: {1: 4, 2: 8}, 3: {1: 6, 2: 18, 3: 26}}
    if ndim not in rank_to_neighbours:
        raise ValueError(
            f"Connected components filtering supports 2D or 3D masks, got {ndim}D"
        )
    by_rank = rank_to_neighbours[ndim]
    # Values already expressed as neighbour counts (e.g. 26) pass through.
    if connectivity in by_rank.values():
        return connectivity
    if connectivity in by_rank:
        return by_rank[connectivity]
    raise ValueError(
        f"Unsupported connectivity {connectivity} for a {ndim}D mask; "
        f"expected one of {sorted(by_rank)} or {sorted(by_rank.values())}"
    )


def _filter_connected_components(
    binary: np.ndarray, cc_config: 'ConnectedComponentsConfig'
) -> np.ndarray:
    """Keep only components passing the min-size and top-k criteria; return a 0/1 mask."""
    connectivity = _resolve_cc3d_connectivity(cc_config.connectivity, binary.ndim)
    labels = cc3d.connected_components(binary, connectivity=connectivity)

    # Voxel count per label; label 0 is background and is never kept.
    component_sizes = np.bincount(labels.ravel())
    component_sizes[0] = 0
    keep = component_sizes > 0

    # Filter by minimum size
    if cc_config.min_size > 0:
        keep &= component_sizes >= cc_config.min_size

    # Keep only top-k largest components (ranked by size; stable sort makes
    # the choice among equally-sized components deterministic).
    top_k = cc_config.top_k
    if top_k is not None and top_k > 0 and len(component_sizes) - 1 > top_k:
        ranked = np.argsort(component_sizes[1:], kind="stable")[-top_k:] + 1
        in_top_k = np.zeros_like(keep)
        in_top_k[ranked] = True
        keep &= in_top_k

    # One lookup pass over the volume instead of one full scan per component.
    return keep[labels].astype(np.uint8)

connectomics/decoding/segmentation.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def decorator(func):
4343

4444

4545
__all__ = [
46+
"decode_binary_thresholding",
4647
"decode_binary_cc",
4748
"decode_binary_watershed",
4849
"decode_binary_contour_cc",
@@ -52,6 +53,80 @@ def decorator(func):
5253
]
5354

5455

56+
def decode_binary_thresholding(
    predictions: np.ndarray,
    threshold_range: Tuple[float, float] = (0.8, 1.0),
) -> np.ndarray:
    r"""Convert binary foreground probability maps to a binary mask via simple thresholding.

    This is a lightweight decoding function that applies thresholding to convert
    probability predictions to binary segmentation masks. Unlike instance segmentation
    methods, this produces a semantic segmentation (no individual instance IDs).

    Only the minimum of ``threshold_range`` is used to binarize predictions.

    Args:
        predictions (numpy.ndarray): foreground probability of shape :math:`(C, Z, Y, X)` or :math:`(C, Y, X)`.
            The first channel (predictions[0]) is used as the foreground probability.
            Values should be in range [0, 1] (normalized) or [0, 255] (uint8).
        threshold_range (tuple): Tuple of (min_threshold, max_threshold) for binarization,
            expressed on the normalized [0, 1] scale. Only the minimum threshold is used.
            Values strictly greater than min_threshold become foreground (1).
            Default: (0.8, 1.0)

    Returns:
        numpy.ndarray: Binary ``uint8`` segmentation mask with shape matching input spatial dimensions.
            Values: 0 (background) or 1 (foreground).
            For 3D input: shape :math:`(Z, Y, X)`
            For 2D input: shape :math:`(Y, X)`
            An empty foreground channel yields an empty mask of the same shape.

    Examples:
        >>> # 3D predictions (normalized [0, 1])
        >>> predictions = np.random.rand(2, 64, 128, 128)  # (C, Z, Y, X)
        >>> binary_mask = decode_binary_thresholding(predictions, threshold_range=(0.8, 1.0))
        >>> print(binary_mask.shape)  # (64, 128, 128)
        >>> print(np.unique(binary_mask))  # [0, 1]

        >>> # 3D predictions (uint8 [0, 255])
        >>> predictions = np.random.randint(0, 256, (2, 64, 128, 128), dtype=np.uint8)
        >>> binary_mask = decode_binary_thresholding(predictions, threshold_range=(0.8, 1.0))

        >>> # 2D predictions
        >>> predictions = np.random.rand(2, 512, 512)  # (C, Y, X)
        >>> binary_mask = decode_binary_thresholding(predictions, threshold_range=(0.5, 1.0))
        >>> print(binary_mask.shape)  # (512, 512)

    Note:
        - **Value-range detection**: If the maximum of the foreground channel exceeds 1.0,
          the input is assumed to be in [0, 255] and the threshold is scaled by 255;
          otherwise the threshold is used as-is.
        - **2D/3D support**: Works with both 2D (C, Y, X) and 3D (C, Z, Y, X) inputs
        - **Channel 0 usage**: Uses first channel (predictions[0]) as foreground probability
        - **Simple thresholding**: No morphological operations or connected components
    """
    # Extract foreground probability (first channel)
    semantic = np.asarray(predictions[0])

    # np.max is undefined on an empty array; there is nothing to threshold.
    if semantic.size == 0:
        return np.zeros(semantic.shape, dtype=np.uint8)

    # Detect whether predictions are in [0, 1] or [0, 255] range and put the
    # threshold on the same scale as the data.
    threshold = threshold_range[0]
    if np.max(semantic) > 1.0:
        threshold = threshold * 255

    # Strict comparison: values equal to the threshold stay background.
    return (semantic > threshold).astype(np.uint8)
128+
129+
55130
def decode_binary_cc(
56131
predictions: np.ndarray,
57132
foreground_threshold: float = 0.8,

0 commit comments

Comments
 (0)