Skip to content

Commit d8919cf

Browse files
committed
added parameter to load model from checkpoint
1 parent 8df45a1 commit d8919cf

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

synapse_net/training/supervised_training.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
from glob import glob
3-
from typing import Optional, Tuple
3+
from typing import Optional, Tuple, Union
44

55
import torch
66
import torch_em
@@ -201,6 +201,7 @@ def supervised_training(
201201
in_channels: int = 1,
202202
out_channels: int = 2,
203203
mask_channel: bool = False,
204+
checkpoint_path: Optional[Union[os.PathLike, str]] = None,
204205
**loader_kwargs,
205206
):
206207
"""Run supervised segmentation training.
@@ -303,7 +304,7 @@ def supervised_training(
303304
loss=loss,
304305
metric=metric,
305306
)
306-
trainer.fit(n_iterations)
307+
trainer.fit(n_iterations, load_from_checkpoint=checkpoint_path)
307308

308309

309310
def _derive_key_from_files(files, key):

0 commit comments

Comments
 (0)