Skip to content

Commit 45ff798

Browse files
authored
Merge pull request #51 from janelia-cellmap/docs
Tests + fix bugs + scriptable
2 parents e32ac5f + 2581dbb commit 45ff798

File tree

23 files changed

+385
-251
lines changed

23 files changed

+385
-251
lines changed

cellmap_flow/blockwise/blockwise_processor.py

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,12 @@
1616
from funlib.persistence import prepare_ds, open_ds, Array
1717
from pathlib import Path
1818

19-
# from cellmap_flow.globals import Flow
20-
import cellmap_flow.globals as g
19+
from cellmap_flow.globals import g
2120
from cellmap_flow.utils.web_utils import encode_to_str, decode_to_json
2221

2322
logger = logging.getLogger(__name__)
2423

2524

26-
def get_output_dtype(model_output):
27-
p_dtype = model_output
28-
# g = Flow()
29-
if len(g.postprocess) > 0:
30-
for postprocess in g.postprocess[::-1]:
31-
if postprocess.dtype:
32-
p_dtype = postprocess.dtype
33-
break
34-
return p_dtype
35-
36-
3725
def get_process_dataset(json_data: str):
3826
logger.error(f"json data: {json_data}")
3927
input_norm_fns = get_normalizations(json_data[INPUT_NORM_DICT_KEY])
@@ -60,7 +48,6 @@ def __init__(self, yaml_config: str, create=True):
6048
self.output_path = self.config["output_path"]
6149
self.output_path = Path(self.output_path)
6250

63-
6451
output_channels = None
6552
if "output_channels" in self.config:
6653
output_channels = self.config["output_channels"].split(",")
@@ -112,7 +99,7 @@ def __init__(self, yaml_config: str, create=True):
11299
else:
113100
self.output_channels = self.channels
114101

115-
self.dtype = get_output_dtype(self.model_config.output_dtype)
102+
self.dtype = g.get_output_dtype(self.model_config.output_dtype)
116103

117104
# g = Flow()
118105

@@ -131,7 +118,6 @@ def __init__(self, yaml_config: str, create=True):
131118
)
132119
self.output_arrays = []
133120

134-
135121
output_shape = (
136122
np.array(self.idi_raw.shape)
137123
* np.array(self.input_voxel_size)
@@ -144,7 +130,7 @@ def __init__(self, yaml_config: str, create=True):
144130
for channel in self.output_channels:
145131
if create:
146132
array = prepare_ds(
147-
DirectoryStore(self.output_path / channel/"s0"),
133+
DirectoryStore(self.output_path / channel / "s0"),
148134
output_shape,
149135
dtype=self.dtype,
150136
chunk_shape=self.block_shape,
@@ -156,7 +142,7 @@ def __init__(self, yaml_config: str, create=True):
156142
else:
157143
try:
158144
array = open_ds(
159-
DirectoryStore(self.output_path / channel/"s0"),
145+
DirectoryStore(self.output_path / channel / "s0"),
160146
"a",
161147
)
162148
except Exception as e:
@@ -247,7 +233,7 @@ def run(self):
247233
import subprocess
248234

249235

250-
def spawn_worker(name, yaml_config,charge_group,queue,ncpu=12):
236+
def spawn_worker(name, yaml_config, charge_group, queue, ncpu=12):
251237
def run_worker():
252238
subprocess.run(
253239
[
@@ -270,7 +256,7 @@ def run_worker():
270256
"run",
271257
"-y",
272258
f"{yaml_config}",
273-
"--client"
259+
"--client",
274260
]
275261
)
276262

cellmap_flow/cli/cli.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,6 @@ def script(script_path, data_path, queue, charge_group):
149149
help="The chargeback group to use when submitting",
150150
default=None,
151151
)
152-
153152
def bioimage(model_path, data_path, edge_length_to_process, queue, charge_group):
154153
command = f"{SERVER_COMMAND} bioimage -m {model_path} -d {data_path} -e {edge_length_to_process}"
155154
base_name = model_path.split("/")[-1].split(".")[0]
@@ -188,7 +187,6 @@ def cellmap_model(config_folder, name, data_path, queue, charge_group):
188187
run(command, data_path, queue, charge_group, name)
189188

190189

191-
192190
@cli.command()
193191
@click.option(
194192
"--script_path",

cellmap_flow/cli/multiple_cli.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import logging
99
from cellmap_flow.utils.bsub_utils import start_hosts, SERVER_COMMAND
1010
from cellmap_flow.utils.neuroglancer_utils import generate_neuroglancer_url
11-
import cellmap_flow.globals as g
11+
from cellmap_flow.globals import g
1212

1313

1414
data_args = ["-d", "--data-path"]
@@ -158,11 +158,9 @@ def main():
158158
else:
159159
j += 1
160160
if not config_folder:
161-
logger.error(
162-
"Missing -c for --celmmap-model sub-command."
163-
)
161+
logger.error("Missing -c for --cellmap-model sub-command.")
164162
sys.exit(1)
165-
models.append(CellMapModelConfig(config_folder, name=name,scale=scale))
163+
models.append(CellMapModelConfig(config_folder, name=name, scale=scale))
166164
i = j
167165
continue
168166

cellmap_flow/cli/multiple_yaml_cli.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from cellmap_flow.utils.bsub_utils import start_hosts, SERVER_COMMAND
55
from cellmap_flow.utils.neuroglancer_utils import generate_neuroglancer_url
66
from cellmap_flow.utils.config_utils import load_config, build_models
7-
import cellmap_flow.globals as g
7+
from cellmap_flow.globals import g
8+
89

910
logger = logging.getLogger(__name__)
1011

@@ -24,10 +25,7 @@ def run_multiple(models, dataset_path, charge_group, queue):
2425

2526
command = f"{SERVER_COMMAND} {model.command} -d {current_data_path}"
2627
start_hosts(
27-
command,
28-
job_name=model.name,
29-
queue=queue,
30-
charge_group=charge_group
28+
command, job_name=model.name, queue=queue, charge_group=charge_group
3129
)
3230

3331
generate_neuroglancer_url(dataset_path)
@@ -49,14 +47,14 @@ def main():
4947
config_path = sys.argv[1]
5048
config = load_config(config_path)
5149

52-
data_path = config['data_path']
53-
charge_group = config['charge_group']
54-
queue = config['queue']
50+
data_path = config["data_path"]
51+
charge_group = config["charge_group"]
52+
queue = config["queue"]
5553

5654
print("Data path:", data_path)
5755

5856
# Build model configuration objects
59-
models = build_models(config['models'])
57+
models = build_models(config["models"])
6058

6159
# For debugging, print each model config
6260
for model in models:

cellmap_flow/dashboard/app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
POSTPROCESS_DICT_KEY,
2323
)
2424
from cellmap_flow.models.run import update_run_models
25-
import cellmap_flow.globals as g
25+
from cellmap_flow.globals import g
2626
import numpy as np
2727
import time
2828

cellmap_flow/globals.py

Lines changed: 117 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,128 @@
1-
jobs = []
1+
from cellmap_flow.norm.input_normalize import MinMaxNormalizer
2+
from cellmap_flow.post.postprocessors import DefaultPostprocessor
3+
from cellmap_flow.models.model_yaml import load_model_paths
4+
import os
5+
import threading
6+
import numpy as np
27

3-
models_config = []
48

5-
servers = []
9+
class Flow:
10+
_instance = None
611

7-
raw = None
12+
def __new__(cls):
13+
if cls._instance is None:
14+
cls._instance = super(Flow, cls).__new__(cls)
15+
cls._instance.jobs = []
16+
cls._instance.models_config = []
17+
cls._instance.servers = []
18+
cls._instance.raw = None
19+
cls._instance.input_norms = [] # or [MinMaxNormalizer(0, 255)]
20+
cls._instance.postprocess = []
21+
cls._instance.viewer = None
22+
cls._instance.dataset_path = None
23+
cls._instance.model_catalog = {}
24+
# Uncomment and adjust if you want to load the model catalog:
25+
# cls._instance.model_catalog = load_model_paths(
26+
# os.path.normpath(os.path.join(os.path.dirname(__file__), os.pardir, "models", "models.yaml"))
27+
# )
28+
cls._instance.queue = "gpu_h100"
29+
cls._instance.charge_group = "cellmap"
30+
cls._instance.neuroglancer_thread = None
31+
return cls._instance
832

33+
def to_dict(self):
34+
return self.__dict__.items()
935

10-
from cellmap_flow.norm.input_normalize import MinMaxNormalizer
11-
from cellmap_flow.post.postprocessors import DefaultPostprocessor
36+
def __repr__(self):
37+
return f"Flow({self.__dict__})"
1238

13-
# input_norms = [MinMaxNormalizer()]
14-
# postprocess = [DefaultPostprocessor(0,200,0,1)]
39+
def __str__(self):
40+
return f"Flow({self.__dict__})"
1541

16-
input_norms = []
17-
postprocess = []
18-
viewer = None
42+
def get_output_dtype(self):
43+
dtype = np.float32
1944

20-
dataset_path = None
45+
if len(self.input_norms) > 0:
46+
for norm in self.input_norms[::-1]:
47+
if norm.dtype:
48+
dtype = norm.dtype
49+
break
2150

51+
if len(self.postprocess) > 0:
52+
for postprocess in self.postprocess[::-1]:
53+
if postprocess.dtype:
54+
dtype = postprocess.dtype
55+
break
2256

23-
from cellmap_flow.models.model_yaml import load_model_paths
57+
return dtype
2458

25-
import os
26-
model_catalog = {}
27-
# model_catalog = load_model_paths(
28-
# os.path.normpath(
29-
# os.path.join(os.path.dirname(__file__), os.pardir, "models", "models.yaml")
30-
# )
31-
# )
32-
33-
queue = "gpu_h100"
34-
charge_group = "cellmap"
59+
@classmethod
60+
def run(
61+
cls,
62+
zarr_path,
63+
model_configs,
64+
queue="gpu_h100",
65+
charge_group="cellmap",
66+
input_normalizers=None,
67+
post_processors=None,
68+
):
69+
70+
from cellmap_flow.utils.bsub_utils import start_hosts, SERVER_COMMAND
71+
from cellmap_flow.utils.neuroglancer_utils import generate_neuroglancer_url
72+
73+
if input_normalizers is None:
74+
input_normalizers = []
75+
if post_processors is None:
76+
post_processors = []
77+
78+
# Get the singleton instance (creates one if it doesn't exist)
79+
instance = cls()
80+
instance.queue = queue
81+
instance.charge_group = charge_group
82+
instance.dataset_path = zarr_path
83+
instance.input_norms = input_normalizers
84+
instance.postprocess = post_processors
85+
instance.models_config = model_configs
86+
instance.neuroglancer_thread = None
87+
88+
threads = []
89+
90+
for model_config in instance.models_config:
91+
model_command = model_config.command
92+
command = f"{SERVER_COMMAND} {model_command} -d {instance.dataset_path}"
93+
print(f"Starting server with command: {command}")
94+
thread = threading.Thread(
95+
target=start_hosts,
96+
args=(command, queue, charge_group, model_config.name),
97+
)
98+
thread.start()
99+
threads.append(thread)
100+
101+
for thread in threads:
102+
thread.join()
103+
104+
instance.neuroglancer_thread = threading.Thread(
105+
target=generate_neuroglancer_url, args=(instance.dataset_path,)
106+
)
107+
instance.neuroglancer_thread.start()
108+
# Optionally wait for the neuroglancer thread:
109+
# instance.neuroglancer_thread.join()
110+
111+
print(f"*****Neuroglancer URL: {instance.dataset_path}")
112+
113+
@classmethod
114+
def stop(cls):
115+
instance = cls()
116+
for job in instance.jobs:
117+
print(f"Killing job {job.job_id}")
118+
job.kill()
119+
if instance.neuroglancer_thread is not None:
120+
instance.neuroglancer_thread = None
121+
instance.jobs = []
122+
123+
@classmethod
124+
def delete(cls):
125+
cls._instance = None
126+
127+
128+
g = Flow()

cellmap_flow/inferencer.py

Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,16 @@
33
import torch
44
from funlib.geometry import Coordinate
55
import logging
6-
from cellmap_flow.utils.data import (
7-
ModelConfig,
8-
BioModelConfig,
9-
DaCapoModelConfig,
10-
ScriptModelConfig,
11-
CellMapModelConfig,
12-
)
13-
import cellmap_flow.globals as g
14-
import neuroglancer
15-
from scipy import spatial
6+
from cellmap_flow.utils.data import ModelConfig
7+
8+
from cellmap_flow.globals import g
169

1710
logger = logging.getLogger(__name__)
1811

1912

2013
def apply_postprocess(data, **kwargs):
2114
for pross in g.postprocess:
15+
# logger.error(f"applying postprocess: {pross}")
2216
data = pross(data, **kwargs)
2317
return data
2418

@@ -93,23 +87,13 @@ def optimize_model(self):
9387
self.model_config.config.model.eval()
9488

9589
def process_chunk(self, idi, roi):
96-
# if isinstance(self.model_config, BioModelConfig):
97-
# return self.process_chunk_bioimagezoo(idi, roi)
98-
if (
99-
isinstance(self.model_config, DaCapoModelConfig)
100-
or isinstance(self.model_config, ScriptModelConfig)
101-
or isinstance(self.model_config, BioModelConfig)
102-
or isinstance(self.model_config, CellMapModelConfig)
90+
# check if process_chunk is in self.config
91+
if getattr(self.model_config.config, "process_chunk", None) and callable(
92+
self.model_config.config.process_chunk
10393
):
104-
# check if process_chunk is in self.config
105-
if getattr(self.model_config.config, "process_chunk", None) and callable(
106-
self.model_config.config.process_chunk
107-
):
108-
result = self.model_config.config.process_chunk(idi, roi)
109-
else:
110-
result = self.process_chunk_basic(idi, roi)
94+
result = self.model_config.config.process_chunk(idi, roi)
11195
else:
112-
raise ValueError(f"Invalid model config type {type(self.model_config)}")
96+
result = self.process_chunk_basic(idi, roi)
11397

11498
postprocessed = apply_postprocess(
11599
result,
@@ -131,4 +115,3 @@ def process_chunk_basic(self, idi, roi):
131115
use_half_prediction=self.use_half_prediction,
132116
)
133117
return result
134-

cellmap_flow/models/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import cellmap_flow.globals as g
1+
from cellmap_flow.globals import g
22

33
from cellmap_flow.utils.bsub_utils import start_hosts, SERVER_COMMAND
44
from cellmap_flow.utils.web_utils import (

0 commit comments

Comments
 (0)