
Commit e8359ee

Merge pull request #136 from computational-cell-analytics/minor-updates
Minor updates
2 parents: 9c252ed + ee8a0ce

5 files changed: +43, -8 lines changed

synapse_net/inference/inference.py
Lines changed: 7 additions & 2 deletions

@@ -155,7 +155,7 @@ def compute_scale_from_voxel_size(
 #


-def _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons):
+def _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons, resolution, min_membrane_size):
     from synapse_net.inference.postprocessing import (
         segment_ribbon, segment_presynaptic_density, segment_membrane_distance_based,
     )
@@ -170,6 +170,7 @@ def _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons
     ref_segmentation = PD if PD.sum() > 0 else ribbon
     membrane = segment_membrane_distance_based(
         predictions["membrane"], ref_segmentation, max_distance=500, n_slices_exclude=n_slices_exclude,
+        resolution=resolution, min_size=min_membrane_size,
     )

     segmentations = {"ribbon": ribbon, "PD": PD, "membrane": membrane}
@@ -182,6 +183,8 @@ def _segment_ribbon_AZ(image, model, tiling, scale, verbose, return_predictions=
     threshold = kwargs.pop("threshold", 0.5)
     n_slices_exclude = kwargs.pop("n_slices_exclude", 20)
     n_ribbons = kwargs.pop("n_slices_exclude", 1)
+    resolution = kwargs.pop("resolution", None)
+    min_membrane_size = kwargs.pop("min_membrane_size", 0)

     predictions = segment_ribbon_synapse_structures(
         image, model=model, tiling=tiling, scale=scale, verbose=verbose, threshold=threshold, **kwargs
@@ -197,7 +200,9 @@ def _segment_ribbon_AZ(image, model, tiling, scale, verbose, return_predictions=
     else:
         if verbose:
             print("Vesicle segmentation was passed, WILL run post-processing.")
-        segmentations = _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons)
+        segmentations = _ribbon_AZ_postprocessing(
+            predictions, vesicles, n_slices_exclude, n_ribbons, resolution, min_membrane_size
+        )

     if return_predictions:
         return segmentations, predictions
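
Note: _segment_ribbon_AZ pops the two new options out of **kwargs with backwards-compatible defaults before the remaining kwargs are forwarded to segment_ribbon_synapse_structures, so callers that omit them keep the previous behaviour. A tiny standalone sketch of this pop-then-forward pattern (the function names below are placeholders, not the package's API):

def predict_structures(image, **kwargs):
    # Stand-in for the prediction call: it never receives the post-processing options.
    assert not kwargs, f"unexpected kwargs: {kwargs}"
    return {"membrane": image}


def segment(image, **kwargs):
    # Post-processing options are popped first, with backwards-compatible defaults...
    resolution = kwargs.pop("resolution", None)
    min_membrane_size = kwargs.pop("min_membrane_size", 0)
    # ...so only prediction-related kwargs are forwarded to the prediction call.
    predictions = predict_structures(image, **kwargs)
    return predictions, resolution, min_membrane_size


print(segment("img", resolution=(1.0, 1.0, 1.0), min_membrane_size=50_000))
# ({'membrane': 'img'}, (1.0, 1.0, 1.0), 50000)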

synapse_net/inference/postprocessing/membranes.py
Lines changed: 19 additions & 0 deletions

@@ -85,7 +85,19 @@ def segment_membrane_distance_based(
     n_slices_exclude: int,
     max_distance: float,
     resolution: Optional[float] = None,
+    min_size: int = 0,
 ):
+    """Derive boundary segmentation from boundary predictions by selecting the fragment closest to the ribbon.
+
+    Args:
+        boundary_prediction: Binary prediction for boundaries in the tomogram.
+        reference_segmentation: The reference segmentation, typically of the ribbon.
+        n_slices_exclude: The number of slices to exclude on the top / bottom
+            in order to avoid segmentation errors due to imaging artifacts in top and bottom.
+        max_distance: The maximal distance from the ribbon to consider.
+        resolution: The resolution / voxel size of the data.
+        min_size: The minimal size of a boundary fragment to be included.
+    """
     assert boundary_prediction.shape == reference_segmentation.shape

     original_shape = boundary_prediction.shape
@@ -95,6 +107,13 @@
     boundary_prediction = boundary_prediction[slice_mask]
     reference_segmentation = reference_segmentation[slice_mask]

+    if min_size > 0:
+        boundary_prediction = label(boundary_prediction, block_shape=(32, 256, 256))
+        ids, sizes = np.unique(boundary_prediction, return_counts=True)
+        ids, sizes = ids[1:], sizes[1:]
+        keep_ids = ids[sizes > min_size]
+        boundary_prediction = np.isin(boundary_prediction, keep_ids)
+
     # Get the unique objects in the reference segmentation.
     reference_ids = np.unique(reference_segmentation)
     assert reference_ids[0] == 0
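
Note: the new min_size filter labels the boundary prediction, measures fragment sizes, and keeps only fragments above the threshold before the distance-based membrane selection runs. A self-contained sketch of the same filtering step, using skimage.measure.label in place of the blockwise label call in the actual code:

import numpy as np
from skimage.measure import label as connected_components


def filter_small_fragments(boundary_prediction: np.ndarray, min_size: int) -> np.ndarray:
    """Keep only connected boundary fragments with more than min_size voxels."""
    labeled = connected_components(boundary_prediction)
    ids, sizes = np.unique(labeled, return_counts=True)
    ids, sizes = ids[1:], sizes[1:]  # drop the background label 0
    keep_ids = ids[sizes > min_size]
    return np.isin(labeled, keep_ids)


# Example: two fragments, only the larger one survives a threshold of 5 voxels.
pred = np.zeros((1, 8, 8), dtype=bool)
pred[0, :3, :3] = True  # 9 voxels -> kept
pred[0, 6, 6] = True    # 1 voxel  -> removed
print(filter_small_fragments(pred, min_size=5).sum())  # 9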

synapse_net/inference/scalable_segmentation.py
Lines changed: 4 additions & 1 deletion

@@ -79,6 +79,7 @@ def scalable_segmentation(
     prediction: Optional[ArrayLike] = None,
     verbose: bool = True,
     mask: Optional[ArrayLike] = None,
+    devices: Optional[List[str]] = None,
 ) -> None:
     """Run segmentation based on a prediction with foreground and boundary channel.

@@ -100,6 +101,8 @@
             If given, this can be a numpy array, a zarr array, or similar
             If not given will be stored in a temporary n5 array.
         verbose: Whether to print timing information.
+        devices: The devices for running prediction. If not given will use the GPU
+            if available, otherwise the CPU.
     """
     if mask is not None:
         raise NotImplementedError
@@ -133,5 +136,5 @@
         seeds = f.create_dataset("seeds", shape=input_.shape, dtype="uint64", chunks=chunks)

     # Run prediction and segmentation.
-    get_prediction(input_, prediction=prediction, tiling=tiling, model=model, verbose=verbose)
+    get_prediction(input_, prediction=prediction, tiling=tiling, model=model, verbose=verbose, devices=devices)
     _run_segmentation(prediction, output, seeds, chunks, seed_threshold, min_size, verbose, original_shape)
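
Note: scalable_segmentation now accepts a devices list and simply forwards it to get_prediction; the new docstring documents the default of using the GPU if available, otherwise the CPU. A small standalone sketch of that defaulting behaviour (resolve_devices is a made-up helper for illustration, not part of the package):

from typing import List, Optional

import torch


def resolve_devices(devices: Optional[List[str]] = None) -> List[str]:
    # Mirrors the documented default: GPU if available, otherwise CPU.
    if devices is None:
        devices = ["cuda" if torch.cuda.is_available() else "cpu"]
    return devices


print(resolve_devices())                      # ["cuda"] on a GPU machine, else ["cpu"]
print(resolve_devices(["cuda:0", "cuda:1"]))  # explicit multi-GPU selection for tiled prediction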

synapse_net/inference/util.py
Lines changed: 11 additions & 4 deletions

@@ -125,6 +125,7 @@ def get_prediction(
     channels_to_standardize: Optional[List[int]] = None,
     mask: Optional[ArrayLike] = None,
     prediction: Optional[ArrayLike] = None,
+    devices: Optional[List[str]] = None,
 ) -> ArrayLike:
     """Run prediction on a given volume.

@@ -143,6 +144,8 @@
             the foreground region of the mask.
         prediction: An array like object for writing the prediction.
             If not given, the prediction will be computed in moemory.
+        devices: The devices for running prediction. If not given will use the GPU
+            if available, otherwise the CPU.

     Returns:
         The predicted volume.
@@ -189,7 +192,7 @@
     # print(f"updated_tiling {updated_tiling}")
     prediction = get_prediction_torch_em(
         input_volume, updated_tiling, model_path, model, verbose, with_channels,
-        mask=mask, prediction=prediction,
+        mask=mask, prediction=prediction, devices=devices,
     )

     return prediction
@@ -204,6 +207,7 @@ def get_prediction_torch_em(
     with_channels: bool = False,
     mask: Optional[ArrayLike] = None,
     prediction: Optional[ArrayLike] = None,
+    devices: Optional[List[str]] = None,
 ) -> np.ndarray:
     """Run prediction using torch-em on a given volume.

@@ -218,6 +222,8 @@
             the foreground region of the mask.
         prediction: An array like object for writing the prediction.
             If not given, the prediction will be computed in moemory.
+        devices: The devices for running prediction. If not given will use the GPU
+            if available, otherwise the CPU.

     Returns:
         The predicted volume.
@@ -227,14 +233,15 @@
     halo = [tiling["halo"]["z"], tiling["halo"]["x"], tiling["halo"]["y"]]

     t0 = time.time()
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    if devices is None:
+        devices = ["cuda" if torch.cuda.is_available() else "cpu"]

     # Suppress warning when loading the model.
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
         if model is None:
             if os.path.isdir(model_path):  # Load the model from a torch_em checkpoint.
-                model = torch_em.util.load_model(checkpoint=model_path, device=device)
+                model = torch_em.util.load_model(checkpoint=model_path, device=devices[0])
             else:  # Load the model directly from a serialized pytorch model.
                 model = torch.load(model_path, weights_only=False)

@@ -253,7 +260,7 @@

     preprocess = None if isinstance(input_volume, np.ndarray) else torch_em.transform.raw.standardize
     prediction = predict_with_halo(
-        input_volume, model, gpu_ids=[device],
+        input_volume, model, gpu_ids=devices,
         block_shape=block_shape, halo=halo,
         preprocess=preprocess, with_channels=with_channels, mask=mask,
         output=prediction,
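
Note: in util.py the single hard-coded device is replaced by a devices list. The first entry is used to load the model, and the full list is passed to predict_with_halo as gpu_ids, so tiles can be spread over several GPUs. An illustrative, standalone sketch of distributing tile origins over a device list (the round-robin scheme here is an assumption for illustration, not torch_em's actual scheduling):

from itertools import cycle
from typing import Dict, List, Tuple


def assign_tiles_to_devices(
    tiles: List[Tuple[int, int, int]], devices: List[str]
) -> Dict[str, List[Tuple[int, int, int]]]:
    """Round-robin assignment of tile origins to devices, as a stand-in for the
    per-GPU scheduling that tiled prediction performs internally."""
    assignment = {device: [] for device in devices}
    for tile, device in zip(tiles, cycle(devices)):
        assignment[device].append(tile)
    return assignment


tiles = [(z, 0, 0) for z in range(0, 128, 32)]  # four z-blocks of a tomogram
print(assign_tiles_to_devices(tiles, ["cuda:0", "cuda:1"]))
# {'cuda:0': [(0, 0, 0), (64, 0, 0)], 'cuda:1': [(32, 0, 0), (96, 0, 0)]}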

synapse_net/tools/segmentation_widget.py
Lines changed: 2 additions & 1 deletion

@@ -187,7 +187,8 @@ def on_predict(self):
         # For these models we read out the 'Extra Segmentation' widget.
         if model_type == "ribbon":  # Currently only the ribbon model needs the extra seg.
             extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name)
-            kwargs = {"extra_segmentation": extra_seg}
+            resolution = tuple(voxel_size[ax] for ax in "zyx")
+            kwargs = {"extra_segmentation": extra_seg, "resolution": resolution, "min_membrane_size": 50_000}
         elif model_type == "cristae":  # Cristae model expects 2 3D volumes
             kwargs = {
                 "extra_segmentation": self._get_layer_selector_data(self.extra_seg_selector_name),
