We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 8df45a1 commit d8919cfCopy full SHA for d8919cf
synapse_net/training/supervised_training.py
@@ -1,6 +1,6 @@
1
import os
2
from glob import glob
3
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
4
5
import torch
6
import torch_em
@@ -201,6 +201,7 @@ def supervised_training(
201
in_channels: int = 1,
202
out_channels: int = 2,
203
mask_channel: bool = False,
204
+ checkpoint_path: Optional[Union[os.PathLike, str]] = None,
205
**loader_kwargs,
206
):
207
"""Run supervised segmentation training.
@@ -303,7 +304,7 @@ def supervised_training(
303
304
loss=loss,
305
metric=metric,
306
)
- trainer.fit(n_iterations)
307
+ trainer.fit(n_iterations, load_from_checkpoint=checkpoint_path)
308
309
310
def _derive_key_from_files(files, key):
0 commit comments