From eb2d3ed59c0f7450c818b23dd5b46c183657af01 Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Tue, 12 Aug 2025 22:52:25 +0200
Subject: [PATCH 1/3] Support setting multiple devices in segmentation

---
 synapse_net/inference/scalable_segmentation.py |  5 ++++-
 synapse_net/inference/util.py                  | 15 +++++++++++----
 2 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/synapse_net/inference/scalable_segmentation.py b/synapse_net/inference/scalable_segmentation.py
index 156fef55..2a0bdc9c 100644
--- a/synapse_net/inference/scalable_segmentation.py
+++ b/synapse_net/inference/scalable_segmentation.py
@@ -79,6 +79,7 @@ def scalable_segmentation(
     prediction: Optional[ArrayLike] = None,
     verbose: bool = True,
     mask: Optional[ArrayLike] = None,
+    devices: Optional[List[str]] = None,
 ) -> None:
     """Run segmentation based on a prediction with foreground and boundary channel.
 
@@ -100,6 +101,8 @@ def scalable_segmentation(
             If given, this can be a numpy array, a zarr array, or similar
             If not given will be stored in a temporary n5 array.
         verbose: Whether to print timing information.
+        devices: The devices for running prediction. If not given will use the GPU
+            if available, otherwise the CPU.
     """
     if mask is not None:
         raise NotImplementedError
@@ -133,5 +136,5 @@ def scalable_segmentation(
         seeds = f.create_dataset("seeds", shape=input_.shape, dtype="uint64", chunks=chunks)
 
     # Run prediction and segmentation.
-    get_prediction(input_, prediction=prediction, tiling=tiling, model=model, verbose=verbose)
+    get_prediction(input_, prediction=prediction, tiling=tiling, model=model, verbose=verbose, devices=devices)
     _run_segmentation(prediction, output, seeds, chunks, seed_threshold, min_size, verbose, original_shape)
diff --git a/synapse_net/inference/util.py b/synapse_net/inference/util.py
index 6bec9bf4..86fa3db3 100644
--- a/synapse_net/inference/util.py
+++ b/synapse_net/inference/util.py
@@ -125,6 +125,7 @@ def get_prediction(
     channels_to_standardize: Optional[List[int]] = None,
     mask: Optional[ArrayLike] = None,
     prediction: Optional[ArrayLike] = None,
+    devices: Optional[List[str]] = None,
 ) -> ArrayLike:
     """Run prediction on a given volume.
 
@@ -143,6 +144,8 @@ def get_prediction(
             the foreground region of the mask.
         prediction: An array like object for writing the prediction.
             If not given, the prediction will be computed in moemory.
+        devices: The devices for running prediction. If not given will use the GPU
+            if available, otherwise the CPU.
 
     Returns:
         The predicted volume.
@@ -189,7 +192,7 @@ def get_prediction(
     # print(f"updated_tiling {updated_tiling}")
     prediction = get_prediction_torch_em(
         input_volume, updated_tiling, model_path, model, verbose, with_channels,
-        mask=mask, prediction=prediction,
+        mask=mask, prediction=prediction, devices=devices,
     )
     return prediction
 
@@ -204,6 +207,7 @@ def get_prediction_torch_em(
     with_channels: bool = False,
     mask: Optional[ArrayLike] = None,
     prediction: Optional[ArrayLike] = None,
+    devices: Optional[List[str]] = None,
 ) -> np.ndarray:
     """Run prediction using torch-em on a given volume.
 
@@ -218,6 +222,8 @@ def get_prediction_torch_em(
             the foreground region of the mask.
         prediction: An array like object for writing the prediction.
             If not given, the prediction will be computed in moemory.
+        devices: The devices for running prediction. If not given will use the GPU
+            if available, otherwise the CPU.
 
     Returns:
         The predicted volume.
@@ -227,14 +233,15 @@ def get_prediction_torch_em(
     halo = [tiling["halo"]["z"], tiling["halo"]["x"], tiling["halo"]["y"]]
 
     t0 = time.time()
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    if devices is None:
+        devices = ["cuda" if torch.cuda.is_available() else "cpu"]
 
     # Suppress warning when loading the model.
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
         if model is None:
             if os.path.isdir(model_path):  # Load the model from a torch_em checkpoint.
-                model = torch_em.util.load_model(checkpoint=model_path, device=device)
+                model = torch_em.util.load_model(checkpoint=model_path, device=devices[0])
             else:  # Load the model directly from a serialized pytorch model.
                 model = torch.load(model_path, weights_only=False)
 
@@ -253,7 +260,7 @@ def get_prediction_torch_em(
     preprocess = None if isinstance(input_volume, np.ndarray) else torch_em.transform.raw.standardize
 
     prediction = predict_with_halo(
-        input_volume, model, gpu_ids=[device],
+        input_volume, model, gpu_ids=devices,
         block_shape=block_shape, halo=halo,
         preprocess=preprocess, with_channels=with_channels, mask=mask, output=prediction,

From efb5851463158a16a5d6cf8feeae5e063f99e79f Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Wed, 13 Aug 2025 00:10:44 +0200
Subject: [PATCH 2/3] Update ribbon structure post-processing in segmentation
 GUI

---
 synapse_net/inference/inference.py                |  9 +++++++--
 .../inference/postprocessing/membranes.py         | 18 ++++++++++++++++++
 synapse_net/tools/segmentation_widget.py          |  3 ++-
 3 files changed, 27 insertions(+), 3 deletions(-)

diff --git a/synapse_net/inference/inference.py b/synapse_net/inference/inference.py
index 47147d71..fd0ee676 100644
--- a/synapse_net/inference/inference.py
+++ b/synapse_net/inference/inference.py
@@ -155,7 +155,7 @@ def compute_scale_from_voxel_size(
 #
 
 
-def _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons):
+def _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons, resolution, min_membrane_size):
     from synapse_net.inference.postprocessing import (
         segment_ribbon, segment_presynaptic_density, segment_membrane_distance_based,
     )
@@ -170,6 +170,7 @@ def _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons
     ref_segmentation = PD if PD.sum() > 0 else ribbon
     membrane = segment_membrane_distance_based(
         predictions["membrane"], ref_segmentation, max_distance=500, n_slices_exclude=n_slices_exclude,
+        resolution=resolution, min_size=min_membrane_size,
     )
 
     segmentations = {"ribbon": ribbon, "PD": PD, "membrane": membrane}
@@ -182,6 +183,8 @@ def _segment_ribbon_AZ(image, model, tiling, scale, verbose, return_predictions=
     threshold = kwargs.pop("threshold", 0.5)
     n_slices_exclude = kwargs.pop("n_slices_exclude", 20)
     n_ribbons = kwargs.pop("n_slices_exclude", 1)
+    resolution = kwargs.pop("resolution", None)
+    min_membrane_size = kwargs.pop("min_membrane_size", 0)
     predictions = segment_ribbon_synapse_structures(
         image, model=model, tiling=tiling, scale=scale, verbose=verbose, threshold=threshold, **kwargs
     )
@@ -197,7 +200,9 @@ def _segment_ribbon_AZ(image, model, tiling, scale, verbose, return_predictions=
     else:
         if verbose:
             print("Vesicle segmentation was passed, WILL run post-processing.")
-        segmentations = _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons)
+        segmentations = _ribbon_AZ_postprocessing(
+            predictions, vesicles, n_slices_exclude, n_ribbons, resolution, min_membrane_size
+        )
 
     if return_predictions:
         return segmentations, predictions
diff --git a/synapse_net/inference/postprocessing/membranes.py b/synapse_net/inference/postprocessing/membranes.py
index 0945e0a3..54cb8d19 100644
--- a/synapse_net/inference/postprocessing/membranes.py
+++ b/synapse_net/inference/postprocessing/membranes.py
@@ -85,7 +85,19 @@ def segment_membrane_distance_based(
     n_slices_exclude: int,
     max_distance: float,
     resolution: Optional[float] = None,
+    min_size: int = 0,
 ):
+    """Derive boundary segmentation from boundary predictions by selecting the fragment closest to the ribbon.
+
+    Args:
+        boundary_prediction: Binary prediction for boundaries in the tomogram.
+        reference_segmentation: The reference segmentation, typically of the ribbon.
+        n_slices_exclude: The number of slices to exclude on the top / bottom
+            in order to avoid segmentation errors due to imaging artifacts in top and bottom.
+        max_distance: The maximal distance from the ribbon to consider.
+        resolution: The resolution / voxel size of the data.
+        min_size: The minimal size of a boundary fragment to be included.
+    """
     assert boundary_prediction.shape == reference_segmentation.shape
 
     original_shape = boundary_prediction.shape
@@ -95,6 +107,12 @@ def segment_membrane_distance_based(
     boundary_prediction = boundary_prediction[slice_mask]
     reference_segmentation = reference_segmentation[slice_mask]
 
+    if min_size > 0:
+        boundary_prediction = label(boundary_prediction, block_shape=(32, 256, 256))
+        ids, sizes = np.unique(boundary_prediction, return_counts=True)
+        keep_ids = ids[sizes > min_size]
+        boundary_prediction = np.isin(boundary_prediction, keep_ids)
+
     # Get the unique objects in the reference segmentation.
     reference_ids = np.unique(reference_segmentation)
     assert reference_ids[0] == 0
diff --git a/synapse_net/tools/segmentation_widget.py b/synapse_net/tools/segmentation_widget.py
index 50fd16e5..27184048 100644
--- a/synapse_net/tools/segmentation_widget.py
+++ b/synapse_net/tools/segmentation_widget.py
@@ -187,7 +187,8 @@ def on_predict(self):
         # For these models we read out the 'Extra Segmentation' widget.
         if model_type == "ribbon":  # Currently only the ribbon model needs the extra seg.
             extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name)
-            kwargs = {"extra_segmentation": extra_seg}
+            resolution = tuple(voxel_size[ax] for ax in "zyx")
+            kwargs = {"extra_segmentation": extra_seg, "resolution": resolution, "min_membrane_size": 50_000}
         elif model_type == "cristae":  # Cristae model expects 2 3D volumes
             kwargs = {
                 "extra_segmentation": self._get_layer_selector_data(self.extra_seg_selector_name),

From ee8a0cea27404ca78d99b23d11625285278b95a9 Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Wed, 13 Aug 2025 09:48:32 +0200
Subject: [PATCH 3/3] Fix bug in new membrane post-processing

---
 synapse_net/inference/postprocessing/membranes.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/synapse_net/inference/postprocessing/membranes.py b/synapse_net/inference/postprocessing/membranes.py
index 54cb8d19..423e04ff 100644
--- a/synapse_net/inference/postprocessing/membranes.py
+++ b/synapse_net/inference/postprocessing/membranes.py
@@ -110,6 +110,7 @@ def segment_membrane_distance_based(
     if min_size > 0:
         boundary_prediction = label(boundary_prediction, block_shape=(32, 256, 256))
         ids, sizes = np.unique(boundary_prediction, return_counts=True)
+        ids, sizes = ids[1:], sizes[1:]
         keep_ids = ids[sizes > min_size]
         boundary_prediction = np.isin(boundary_prediction, keep_ids)
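
Note (illustration, not part of the patches): the snippet below is a toy sketch of the background-label issue fixed by PATCH 3/3. np.unique on a labeled volume lists the background id 0 first, and its count is the largest in the volume, so unless that entry is dropped the background passes the min_size filter and np.isin marks every background voxel as foreground. The array values and the size threshold here are made up for the example.

    import numpy as np

    # Toy labeled volume: background (0), one larger fragment (1), one tiny fragment (2).
    labels = np.array([0, 0, 0, 0, 1, 1, 2])
    min_size = 1

    ids, sizes = np.unique(labels, return_counts=True)
    ids, sizes = ids[1:], sizes[1:]       # drop the background entry, as in PATCH 3/3
    keep_ids = ids[sizes > min_size]      # only fragments larger than min_size survive
    filtered = np.isin(labels, keep_ids)

    print(filtered)  # [False False False False  True  True False]

Without the ids[1:], sizes[1:] line, keep_ids would also contain 0, and np.isin would return True for all background voxels, turning the whole volume into a "boundary fragment". The devices change from PATCH 1/3 is independent of this: passing a list such as ["cuda:0", "cuda:1"] through scalable_segmentation or get_prediction forwards it to predict_with_halo as gpu_ids, and devices[0] is also used when loading a model from a torch_em checkpoint.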