Skip to content

Commit 75d60f9

Browse files
Update inference CLI so that it works with checkpoints
1 parent f76dbf7 commit 75d60f9

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

synapse_net/tools/cli.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import argparse
2+
import os
23
from functools import partial
34

45
import torch
6+
import torch_em
57
from ..imod.to_imod import export_helper, write_segmentation_to_imod_as_points, write_segmentation_to_imod
68
from ..inference.inference import _get_model_registry, get_model, get_model_training_resolution, run_segmentation
79
from ..inference.util import inference_helper, parse_tiling
@@ -155,7 +157,14 @@ def segmentation_cli():
155157
if args.checkpoint is None:
156158
model = get_model(args.model)
157159
else:
158-
model = torch.load(args.checkpoint, weights_only=False)
160+
checkpoint_path = args.checkpoint
161+
if checkpoint_path.endswith("best.pt"):
162+
checkpoint_path = os.path.split(checkpoint_path)[0]
163+
164+
if os.path.isdir(checkpoint_path): # Load the model from a torch_em checkpoint.
165+
model = torch_em.util.load_model(checkpoint=checkpoint_path)
166+
else:
167+
model = torch.load(checkpoint_path, weights_only=False)
159168
assert model is not None, f"The model from {args.checkpoint} could not be loaded."
160169

161170
is_2d = "2d" in args.model

0 commit comments

Comments
 (0)