Skip to content

Commit a51bf2a

Browse files
Start implementation of CLI functionality
1 parent 3c5e046 commit a51bf2a

File tree

2 files changed

+60
-1
lines changed

2 files changed

+60
-1
lines changed

flamingo_tools/segmentation/cli.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
"""private
2+
"""
3+
import argparse
4+
5+
from .unet_prediction import run_unet_prediction
6+
from .synapse_detection import marker_detection
7+
from ..model_utils import get_model_path
8+
9+
10+
def _get_model_path(model_type, checkpoint_path=None):
11+
if checkpoint_path is None:
12+
model_path = get_model_path(model_type)
13+
else:
14+
model_path = ... # TODO
15+
return model_path
16+
17+
18+
def run_segmentation():
19+
"""private
20+
"""
21+
parser = argparse.ArgumentParser(description="")
22+
parser.add_argument("-i", "--input_path", required=True, help="The path to the input data.")
23+
parser.add_argument("-k", "--input_key", required=True, help="The key to the input data.")
24+
parser.add_argument("-o", "--output_folder", required=True)
25+
parser.add_argument("-m", "--model_type", required=True)
26+
parser.add_argument("-c", "--checkpoint_path")
27+
parser.add_argument("--min_size", type=int, default=250)
28+
# TODO other stuff
29+
args = parser.parse_args()
30+
31+
segmentation_models = ["SGN", "IHC", "SGN-lowres", "IHC-lowres"]
32+
if args.model_type not in segmentation_models:
33+
raise ValueError
34+
model_path = _get_model_path(args.model_type, args.checkpoint_path)
35+
run_unet_prediction(
36+
input_path=args.input_path, input_key=args.input_key,
37+
output_folder=args.output_folder, model_path=model_path,
38+
min_size=args.min_size,
39+
)
40+
41+
42+
def run_detection():
43+
"""private
44+
"""
45+
parser = argparse.ArgumentParser()
46+
parser.add_argument("-m", "--model_type", default="Synapses")
47+
args = parser.parse_args()
48+
detection_models = ["Synapses"]
49+
if args.model_type not in detection_models:
50+
raise ValueError
51+
model_path = _get_model_path(args.model_type, args.checkpoint_path)
52+
# TODO
53+
marker_detection(
54+
input_path=args.input_path, input_key=args.input_key,
55+
output_folder=args.output_folder, model_path=model_path,
56+
)

setup.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
license="MIT",
1111
entry_points={
1212
"console_scripts": [
13-
"convert_flamingo = flamingo_tools.data_conversion:convert_lightsheet_to_bdv_cli"
13+
"flamingo_tools.convert_data = flamingo_tools.data_conversion:convert_lightsheet_to_bdv_cli",
14+
"flamingo_tools.run_segmentation = flamingo_tools.segmentation.cli:run_segmentation",
15+
"flamingo_tools.run_detection = flamingo_tools.segmentation.cli:run_detection",
16+
# TODO: MoBIE conversion, tonotopic mapping
1417
],
1518
"napari.manifest": [
1619
"cochlea_net = flamingo_tools:napari.yaml",

0 commit comments

Comments
 (0)