Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions synapse_net/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def compute_scale_from_voxel_size(
#


def _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons):
def _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons, resolution, min_membrane_size):
from synapse_net.inference.postprocessing import (
segment_ribbon, segment_presynaptic_density, segment_membrane_distance_based,
)
Expand All @@ -170,6 +170,7 @@ def _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons
ref_segmentation = PD if PD.sum() > 0 else ribbon
membrane = segment_membrane_distance_based(
predictions["membrane"], ref_segmentation, max_distance=500, n_slices_exclude=n_slices_exclude,
resolution=resolution, min_size=min_membrane_size,
)

segmentations = {"ribbon": ribbon, "PD": PD, "membrane": membrane}
Expand All @@ -182,6 +183,8 @@ def _segment_ribbon_AZ(image, model, tiling, scale, verbose, return_predictions=
threshold = kwargs.pop("threshold", 0.5)
n_slices_exclude = kwargs.pop("n_slices_exclude", 20)
n_ribbons = kwargs.pop("n_slices_exclude", 1)
resolution = kwargs.pop("resolution", None)
min_membrane_size = kwargs.pop("min_membrane_size", 0)

predictions = segment_ribbon_synapse_structures(
image, model=model, tiling=tiling, scale=scale, verbose=verbose, threshold=threshold, **kwargs
Expand All @@ -197,7 +200,9 @@ def _segment_ribbon_AZ(image, model, tiling, scale, verbose, return_predictions=
else:
if verbose:
print("Vesicle segmentation was passed, WILL run post-processing.")
segmentations = _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons)
segmentations = _ribbon_AZ_postprocessing(
predictions, vesicles, n_slices_exclude, n_ribbons, resolution, min_membrane_size
)

if return_predictions:
return segmentations, predictions
Expand Down
19 changes: 19 additions & 0 deletions synapse_net/inference/postprocessing/membranes.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,19 @@ def segment_membrane_distance_based(
n_slices_exclude: int,
max_distance: float,
resolution: Optional[float] = None,
min_size: int = 0,
):
"""Derive boundary segmentation from boundary predictions by selecting the fragment closest to the ribbon.

Args:
boundary_prediction: Binary prediction for boundaries in the tomogram.
reference_segmentation: The reference segmentation, typically of the ribbon.
n_slices_exclude: The number of slices to exclude on the top / bottom
in order to avoid segmentation errors due to imaging artifacts in top and bottom.
max_distance: The maximal distance from the ribbon to consider.
resolution: The resolution / voxel size of the data.
min_size: The minimal size of a boundary fragment to be included.
"""
assert boundary_prediction.shape == reference_segmentation.shape

original_shape = boundary_prediction.shape
Expand All @@ -95,6 +107,13 @@ def segment_membrane_distance_based(
boundary_prediction = boundary_prediction[slice_mask]
reference_segmentation = reference_segmentation[slice_mask]

if min_size > 0:
boundary_prediction = label(boundary_prediction, block_shape=(32, 256, 256))
ids, sizes = np.unique(boundary_prediction, return_counts=True)
ids, sizes = ids[1:], sizes[1:]
keep_ids = ids[sizes > min_size]
boundary_prediction = np.isin(boundary_prediction, keep_ids)

# Get the unique objects in the reference segmentation.
reference_ids = np.unique(reference_segmentation)
assert reference_ids[0] == 0
Expand Down
5 changes: 4 additions & 1 deletion synapse_net/inference/scalable_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def scalable_segmentation(
prediction: Optional[ArrayLike] = None,
verbose: bool = True,
mask: Optional[ArrayLike] = None,
devices: Optional[List[str]] = None,
) -> None:
"""Run segmentation based on a prediction with foreground and boundary channel.

Expand All @@ -100,6 +101,8 @@ def scalable_segmentation(
If given, this can be a numpy array, a zarr array, or similar
If not given will be stored in a temporary n5 array.
verbose: Whether to print timing information.
devices: The devices for running prediction. If not given will use the GPU
if available, otherwise the CPU.
"""
if mask is not None:
raise NotImplementedError
Expand Down Expand Up @@ -133,5 +136,5 @@ def scalable_segmentation(
seeds = f.create_dataset("seeds", shape=input_.shape, dtype="uint64", chunks=chunks)

# Run prediction and segmentation.
get_prediction(input_, prediction=prediction, tiling=tiling, model=model, verbose=verbose)
get_prediction(input_, prediction=prediction, tiling=tiling, model=model, verbose=verbose, devices=devices)
_run_segmentation(prediction, output, seeds, chunks, seed_threshold, min_size, verbose, original_shape)
15 changes: 11 additions & 4 deletions synapse_net/inference/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def get_prediction(
channels_to_standardize: Optional[List[int]] = None,
mask: Optional[ArrayLike] = None,
prediction: Optional[ArrayLike] = None,
devices: Optional[List[str]] = None,
) -> ArrayLike:
"""Run prediction on a given volume.

Expand All @@ -143,6 +144,8 @@ def get_prediction(
the foreground region of the mask.
prediction: An array like object for writing the prediction.
If not given, the prediction will be computed in moemory.
devices: The devices for running prediction. If not given will use the GPU
if available, otherwise the CPU.

Returns:
The predicted volume.
Expand Down Expand Up @@ -189,7 +192,7 @@ def get_prediction(
# print(f"updated_tiling {updated_tiling}")
prediction = get_prediction_torch_em(
input_volume, updated_tiling, model_path, model, verbose, with_channels,
mask=mask, prediction=prediction,
mask=mask, prediction=prediction, devices=devices,
)

return prediction
Expand All @@ -204,6 +207,7 @@ def get_prediction_torch_em(
with_channels: bool = False,
mask: Optional[ArrayLike] = None,
prediction: Optional[ArrayLike] = None,
devices: Optional[List[str]] = None,
) -> np.ndarray:
"""Run prediction using torch-em on a given volume.

Expand All @@ -218,6 +222,8 @@ def get_prediction_torch_em(
the foreground region of the mask.
prediction: An array like object for writing the prediction.
If not given, the prediction will be computed in moemory.
devices: The devices for running prediction. If not given will use the GPU
if available, otherwise the CPU.

Returns:
The predicted volume.
Expand All @@ -227,14 +233,15 @@ def get_prediction_torch_em(
halo = [tiling["halo"]["z"], tiling["halo"]["x"], tiling["halo"]["y"]]

t0 = time.time()
device = "cuda" if torch.cuda.is_available() else "cpu"
if devices is None:
devices = ["cuda" if torch.cuda.is_available() else "cpu"]

# Suppress warning when loading the model.
with warnings.catch_warnings():
warnings.simplefilter("ignore")
if model is None:
if os.path.isdir(model_path): # Load the model from a torch_em checkpoint.
model = torch_em.util.load_model(checkpoint=model_path, device=device)
model = torch_em.util.load_model(checkpoint=model_path, device=devices[0])
else: # Load the model directly from a serialized pytorch model.
model = torch.load(model_path, weights_only=False)

Expand All @@ -253,7 +260,7 @@ def get_prediction_torch_em(

preprocess = None if isinstance(input_volume, np.ndarray) else torch_em.transform.raw.standardize
prediction = predict_with_halo(
input_volume, model, gpu_ids=[device],
input_volume, model, gpu_ids=devices,
block_shape=block_shape, halo=halo,
preprocess=preprocess, with_channels=with_channels, mask=mask,
output=prediction,
Expand Down
3 changes: 2 additions & 1 deletion synapse_net/tools/segmentation_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ def on_predict(self):
# For these models we read out the 'Extra Segmentation' widget.
if model_type == "ribbon": # Currently only the ribbon model needs the extra seg.
extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name)
kwargs = {"extra_segmentation": extra_seg}
resolution = tuple(voxel_size[ax] for ax in "zyx")
kwargs = {"extra_segmentation": extra_seg, "resolution": resolution, "min_membrane_size": 50_000}
elif model_type == "cristae": # Cristae model expects 2 3D volumes
kwargs = {
"extra_segmentation": self._get_layer_selector_data(self.extra_seg_selector_name),
Expand Down
Loading