Skip to content

Commit 21f9850

Browse files
authored
Add prepocess to pred func (#149)
* added torch_em load_model to supervised training * added docs for chekcpoint_path * added preprocess to predcition function so that input for predict_with_halo can be prepocessed per tile
1 parent e362fe0 commit 21f9850

File tree

3 files changed

+18
-4
lines changed

3 files changed

+18
-4
lines changed

synapse_net/inference/mitochondria.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import time
2-
from typing import Dict, List, Optional, Tuple, Union
2+
from typing import Callable, Dict, List, Optional, Tuple, Union
33

44
import elf.parallel as parallel
55
import numpy as np
@@ -65,6 +65,7 @@ def segment_mitochondria(
6565
ws_halo: Tuple[int, ...] = (48, 48, 48),
6666
boundary_threshold: float = 0.25,
6767
area_threshold: int = 5000,
68+
preprocess: Callable = None,
6869
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
6970
"""Segment mitochondria in an input volume.
7071
@@ -99,7 +100,10 @@ def segment_mitochondria(
99100
# Rescale the mask if it was given and run prediction.
100101
if mask is not None:
101102
mask = scaler.scale_input(mask, is_segmentation=True)
102-
pred = get_prediction(input_volume, model_path=model_path, model=model, tiling=tiling, mask=mask, verbose=verbose)
103+
pred = get_prediction(
104+
input_volume, model_path=model_path, model=model, tiling=tiling, mask=mask, verbose=verbose,
105+
preprocess=preprocess
106+
)
103107

104108
# Run segmentation and rescale the result if necessary.
105109
foreground, boundaries = pred[:2]

synapse_net/inference/util.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def get_prediction(
126126
mask: Optional[ArrayLike] = None,
127127
prediction: Optional[ArrayLike] = None,
128128
devices: Optional[List[str]] = None,
129+
preprocess: Optional[callable] = None,
129130
) -> ArrayLike:
130131
"""Run prediction on a given volume.
131132
@@ -192,7 +193,7 @@ def get_prediction(
192193
# print(f"updated_tiling {updated_tiling}")
193194
prediction = get_prediction_torch_em(
194195
input_volume, updated_tiling, model_path, model, verbose, with_channels,
195-
mask=mask, prediction=prediction, devices=devices,
196+
mask=mask, prediction=prediction, devices=devices, preprocess=preprocess,
196197
)
197198

198199
return prediction
@@ -208,6 +209,7 @@ def get_prediction_torch_em(
208209
mask: Optional[ArrayLike] = None,
209210
prediction: Optional[ArrayLike] = None,
210211
devices: Optional[List[str]] = None,
212+
preprocess: Optional[callable] = None,
211213
) -> np.ndarray:
212214
"""Run prediction using torch-em on a given volume.
213215
@@ -258,7 +260,10 @@ def get_prediction_torch_em(
258260
print("Run prediction with mask.")
259261
mask = mask.astype("bool")
260262

261-
preprocess = None if isinstance(input_volume, np.ndarray) else torch_em.transform.raw.standardize
263+
if preprocess is None:
264+
preprocess = None if isinstance(input_volume, np.ndarray) else torch_em.transform.raw.standardize
265+
else:
266+
preprocess = preprocess
262267
prediction = predict_with_halo(
263268
input_volume, model, gpu_ids=devices,
264269
block_shape=block_shape, halo=halo,

synapse_net/training/supervised_training.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ def supervised_training(
201201
in_channels: int = 1,
202202
out_channels: int = 2,
203203
mask_channel: bool = False,
204+
checkpoint_path: Optional[str] = None,
204205
**loader_kwargs,
205206
):
206207
"""Run supervised segmentation training.
@@ -243,6 +244,7 @@ def supervised_training(
243244
out_channels: The number of output channels of the UNet.
244245
mask_channel: Whether the last channels in the labels should be used for masking the loss.
245246
This can be used to implement more complex masking operations and is not compatible with `ignore_label`.
247+
checkpoint_path: Path to the directory where 'best.pt' resides; continue training this model.
246248
loader_kwargs: Additional keyword arguments for the dataloader.
247249
"""
248250
train_loader = get_supervised_loader(train_paths, raw_key, label_key, patch_shape, batch_size,
@@ -265,6 +267,9 @@ def supervised_training(
265267
model = get_2d_model(out_channels=out_channels, in_channels=in_channels)
266268
else:
267269
model = get_3d_model(out_channels=out_channels, in_channels=in_channels)
270+
271+
if checkpoint_path:
272+
model = torch_em.util.load_model(checkpoint=checkpoint_path)
268273

269274
loss, metric = None, None
270275
# No ignore label -> we can use default loss.

0 commit comments

Comments
 (0)