Skip to content

Commit 3fecc1f

Browse files
committed
added torch_em load_model to supervised training
1 parent 8df45a1 commit 3fecc1f

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

synapse_net/training/supervised_training.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ def supervised_training(
201201
in_channels: int = 1,
202202
out_channels: int = 2,
203203
mask_channel: bool = False,
204+
checkpoint_path: Optional[str] = None,
204205
**loader_kwargs,
205206
):
206207
"""Run supervised segmentation training.
@@ -265,6 +266,9 @@ def supervised_training(
265266
model = get_2d_model(out_channels=out_channels, in_channels=in_channels)
266267
else:
267268
model = get_3d_model(out_channels=out_channels, in_channels=in_channels)
269+
270+
if checkpoint_path:
271+
model = torch_em.util.load_model(checkpoint=checkpoint_path)
268272

269273
loss, metric = None, None
270274
# No ignore label -> we can use default loss.

0 commit comments

Comments
 (0)