Skip to content

Commit 6c7bc12

Browse files
Fix scale in vesicle segmentation; implement slurm array jobs for parallelization
1 parent bab12d8 commit 6c7bc12

File tree

2 files changed

+62
-12
lines changed

2 files changed

+62
-12
lines changed

scripts/cryo/cryo-et-portal/process_tomograms_on_the_fly.py

Lines changed: 54 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import argparse
22
import os
3+
import subprocess
4+
35
import cryoet_data_portal as cdp
6+
import numpy as np
47
import zarr
58

69
from ome_zarr.io import parse_url
@@ -9,12 +12,17 @@
912
from synapse_net.inference.vesicles import segment_vesicles
1013
from tqdm import tqdm
1114

15+
# OUTPUT_ROOT = ""
16+
OUTPUT_ROOT = "/mnt/vast-nhr/projects/nim00007/data/synaptic-reconstruction/portal"
17+
1218

13-
def get_tomograms(deposition_id, processing_type, number_of_tomograms=None):
    """Query the CZ cryo-ET data portal for the tomograms of a deposition.

    Args:
        deposition_id: The id of the deposition on the portal.
        processing_type: The processing state to filter for (e.g. "denoised").
        number_of_tomograms: If given, only the first so many tomograms
            of the query result are returned. By default all are returned.

    Returns:
        The list of matching portal tomogram objects.
    """
    client = cdp.Client()
    query_filters = [
        cdp.Tomogram.deposition_id == deposition_id,
        cdp.Tomogram.processing == processing_type,
    ]
    tomograms = cdp.Tomogram.find(client, query_filters)
    return tomograms if number_of_tomograms is None else tomograms[:number_of_tomograms]
1927

2028

@@ -30,7 +38,7 @@ def write_ome_zarr(output_file, segmentation, voxel_size):
3038

3139

3240
def run_prediction(tomogram, deposition_id, processing_type):
33-
output_folder = os.path.join(f"upload_CZCDP-{deposition_id}", str(tomogram.run.dataset_id))
41+
output_folder = os.path.join(OUTPUT_ROOT, f"upload_CZCDP-{deposition_id}", str(tomogram.run.dataset_id))
3442
os.makedirs(output_folder, exist_ok=True)
3543

3644
output_file = os.path.join(output_folder, f"{tomogram.run.name}.zarr")
@@ -45,7 +53,8 @@ def run_prediction(tomogram, deposition_id, processing_type):
4553

4654
# Segment vesicles.
4755
model_path = "/mnt/lustre-emmy-hdd/projects/nim00007/models/synaptic-reconstruction/vesicle-DA-portal-v3"
48-
segmentation = segment_vesicles(data, model_path=model_path)
56+
scale = (1.0 / 2.7,) * 3
57+
segmentation = segment_vesicles(data, model_path=model_path, scale=scale)
4958

5059
# Save the segmentation.
5160
write_ome_zarr(output_file, segmentation, voxel_size)
@@ -77,27 +86,60 @@ def check_result(tomogram, deposition_id, processing_type):
7786
napari.run()
7887

7988

89+
def _get_task_tomograms(tomograms, slurm_tasks, task_id):
90+
# TODO we could also filter already done tomos.
91+
inputs_to_tasks = np.array_split(tomograms, slurm_tasks)
92+
assert len(inputs_to_tasks) == slurm_tasks
93+
return inputs_to_tasks[task_id]
94+
95+
96+
def process_slurm(args, tomograms, deposition_id, processing_type):
    """Run the predictions for the given tomograms as a slurm array job.

    When called outside of a slurm task (SLURM_ARRAY_TASK_ID is not set in
    the environment) this submits the array job via sbatch. Inside a slurm
    task it processes only the tomograms assigned to that task id.

    Args:
        args: The parsed command line arguments; ``args.slurm_tasks`` gives
            the number of array tasks and ``args.check`` must be False.
        tomograms: The full list of tomograms to process.
        deposition_id: The id of the deposition on the portal.
        processing_type: The processing state of the tomograms.

    Raises:
        ValueError: If ``args.check`` is set; interactive checking with
            napari is not possible inside a batch job.
    """
    # Explicit check instead of `assert`, which is stripped when running with -O.
    if args.check:
        raise ValueError("The check option is not supported for slurm processing.")
    task_id = os.environ.get("SLURM_ARRAY_TASK_ID")

    if task_id is None:  # We are not in the slurm task and submit the job.
        # Assemble the command for submitting a slurm array job.
        script_path = "process_tomograms_on_the_fly.sbatch"
        cmd = ["sbatch", "-a", f"0-{args.slurm_tasks-1}", script_path, "-s", str(args.slurm_tasks)]
        print("Submitting to slurm:")
        print(cmd)
        # check=True so that a failed sbatch submission does not go unnoticed.
        subprocess.run(cmd, check=True)
    else:  # We are in the task.
        task_id = int(task_id)
        this_tomograms = _get_task_tomograms(tomograms, args.slurm_tasks, task_id)
        for tomogram in tqdm(this_tomograms, desc="Run prediction for tomograms on-the-fly"):
            run_prediction(tomogram, deposition_id, processing_type)
112+
113+
114+
def process_local(args, tomograms, deposition_id, processing_type):
    """Process all tomograms sequentially on this machine.

    Depending on ``args.check``, each tomogram is either opened for visual
    inspection (``check_result``) or segmented (``run_prediction``).
    """
    print("Start processing", len(tomograms), "tomograms")
    # args.check is constant over the loop, so pick the handler once up front.
    handle = check_result if args.check else run_prediction
    for tomogram in tqdm(tomograms, desc="Run prediction for tomograms on-the-fly"):
        handle(tomogram, deposition_id, processing_type)
122+
123+
80124
def main():
    """Command line entry point: query the portal and run or check segmentations."""
    parser = argparse.ArgumentParser()
    # Whether to check the result with napari instead of running the prediction.
    parser.add_argument("-c", "--check", action="store_true")
    parser.add_argument("-n", "--number_of_tomograms", type=int, default=None)
    parser.add_argument("-s", "--slurm_tasks", type=int, default=None)
    args = parser.parse_args()

    deposition_id = 10313
    processing_type = "denoised"

    # Get all the (processed) tomogram ids in the deposition.
    tomograms = get_tomograms(deposition_id, processing_type, args.number_of_tomograms)

    # Distribute over a slurm array job if requested, otherwise run locally.
    if args.slurm_tasks is not None:
        process_slurm(args, tomograms, deposition_id, processing_type)
    else:
        process_local(args, tomograms, deposition_id, processing_type)
99142

100143

101-
# TODO segmented at wrong size, check voxel size!
102144
if __name__ == "__main__":
103145
main()
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#!/bin/bash
2+
#SBATCH --partition=grete:shared
3+
#SBATCH -G A100:1
4+
#SBATCH --time=00-08:00:00
5+
#SBATCH --nodes=1
6+
#SBATCH -c 12
7+
8+
/scratch-grete/usr/nimcpape/software/mamba/envs/sam/bin/python process_tomograms_on_the_fly.py $@

0 commit comments

Comments
 (0)