
Commit 586623c

added preprocess to prediction function so that the input for predict_with_halo can be preprocessed per tile
1 parent 6263ed3 commit 586623c

File tree

2 files changed: +13 −4 lines changed


synapse_net/inference/mitochondria.py

Lines changed: 6 additions & 2 deletions
@@ -1,5 +1,5 @@
 import time
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import elf.parallel as parallel
 import numpy as np
@@ -65,6 +65,7 @@ def segment_mitochondria(
     ws_halo: Tuple[int, ...] = (48, 48, 48),
     boundary_threshold: float = 0.25,
     area_threshold: int = 5000,
+    preprocess: Callable = None,
 ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
     """Segment mitochondria in an input volume.
 
@@ -99,7 +100,10 @@ def segment_mitochondria(
     # Rescale the mask if it was given and run prediction.
     if mask is not None:
         mask = scaler.scale_input(mask, is_segmentation=True)
-    pred = get_prediction(input_volume, model_path=model_path, model=model, tiling=tiling, mask=mask, verbose=verbose)
+    pred = get_prediction(
+        input_volume, model_path=model_path, model=model, tiling=tiling, mask=mask, verbose=verbose,
+        preprocess=preprocess
+    )
 
     # Run segmentation and rescale the result if necessary.
     foreground, boundaries = pred[:2]
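
For context, a minimal usage sketch (not part of this commit) of the new preprocess argument from the caller's side. It assumes segment_mitochondria is imported from synapse_net.inference.mitochondria and that the callable receives each raw tile as a numpy array and returns the array to feed to the network, mirroring torch_em.transform.raw.standardize; the clip_and_standardize helper and the volume/model_path placeholders are illustrative only.

# Hypothetical usage sketch: pass a custom per-tile normalization
# instead of the default standardization.
import numpy as np

from synapse_net.inference.mitochondria import segment_mitochondria


def clip_and_standardize(tile: np.ndarray) -> np.ndarray:
    """Assumed per-tile contract: take the raw tile, return the preprocessed tile."""
    lo, hi = np.percentile(tile, (1.0, 99.0))
    tile = np.clip(tile, lo, hi).astype("float32")
    return (tile - tile.mean()) / (tile.std() + 1e-7)


# 'volume' and 'model_path' are placeholders for the user's data and trained model.
# seg = segment_mitochondria(volume, model_path=model_path, preprocess=clip_and_standardize)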

synapse_net/inference/util.py

Lines changed: 7 additions & 2 deletions
@@ -126,6 +126,7 @@ def get_prediction(
     mask: Optional[ArrayLike] = None,
     prediction: Optional[ArrayLike] = None,
     devices: Optional[List[str]] = None,
+    preprocess: Optional[callable] = None,
 ) -> ArrayLike:
     """Run prediction on a given volume.
 
@@ -192,7 +193,7 @@ def get_prediction(
     # print(f"updated_tiling {updated_tiling}")
     prediction = get_prediction_torch_em(
         input_volume, updated_tiling, model_path, model, verbose, with_channels,
-        mask=mask, prediction=prediction, devices=devices,
+        mask=mask, prediction=prediction, devices=devices, preprocess=preprocess,
     )
 
     return prediction
@@ -208,6 +209,7 @@ def get_prediction_torch_em(
     mask: Optional[ArrayLike] = None,
     prediction: Optional[ArrayLike] = None,
     devices: Optional[List[str]] = None,
+    preprocess: Optional[callable] = None,
 ) -> np.ndarray:
     """Run prediction using torch-em on a given volume.
 
@@ -258,7 +260,10 @@
         print("Run prediction with mask.")
         mask = mask.astype("bool")
 
-    preprocess = None if isinstance(input_volume, np.ndarray) else torch_em.transform.raw.standardize
+    if preprocess is None:
+        preprocess = None if isinstance(input_volume, np.ndarray) else torch_em.transform.raw.standardize
+    else:
+        preprocess = preprocess
     prediction = predict_with_halo(
         input_volume, model, gpu_ids=devices,
         block_shape=block_shape, halo=halo,
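
The new branch keeps the previous default (standardize only for non-numpy inputs such as on-disk datasets) and otherwise forwards the user callable unchanged to predict_with_halo, which applies it per tile before inference. Below is a standalone sketch of that selection logic; resolve_preprocess and log_scale are illustrative names, not part of the codebase.

# Sketch of the fallback logic added above.
from typing import Callable, Optional

import numpy as np
import torch_em


def log_scale(tile: np.ndarray) -> np.ndarray:
    """Hypothetical user preprocess: log-compress intensities per tile."""
    return np.log1p(tile.astype("float32"))


def resolve_preprocess(input_volume, preprocess: Optional[Callable]) -> Optional[Callable]:
    # Same decision as in get_prediction_torch_em: keep the old default when no
    # callable is given, otherwise use the user-supplied one for every tile.
    if preprocess is None:
        return None if isinstance(input_volume, np.ndarray) else torch_em.transform.raw.standardize
    return preprocess


volume = np.random.rand(32, 64, 64).astype("float32")
assert resolve_preprocess(volume, None) is None            # numpy input, old default
assert resolve_preprocess(volume, log_scale) is log_scale  # user callable wins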
