11import argparse
22import os
3+ import subprocess
4+
35import cryoet_data_portal as cdp
6+ import numpy as np
47import zarr
58
69from ome_zarr .io import parse_url
912from synapse_net .inference .vesicles import segment_vesicles
1013from tqdm import tqdm
1114
# Root folder under which the "upload_CZCDP-<deposition_id>" output folders are created.
# Use the empty string to write the outputs relative to the current working directory.
# OUTPUT_ROOT = ""
OUTPUT_ROOT = "/mnt/vast-nhr/projects/nim00007/data/synaptic-reconstruction/portal"
17+
1218
def get_tomograms(deposition_id, processing_type, number_of_tomograms=None):
    """Look up the tomograms of a deposition that have the given processing type.

    Args:
        deposition_id: Value matched against `cdp.Tomogram.deposition_id`.
        processing_type: Value matched against `cdp.Tomogram.processing`.
        number_of_tomograms: If given, only the first `number_of_tomograms`
            results are returned. By default all results are returned.

    Returns:
        The result of `cdp.Tomogram.find`, optionally truncated by slicing.
    """
    portal_client = cdp.Client()
    query_filters = [
        cdp.Tomogram.deposition_id == deposition_id,
        cdp.Tomogram.processing == processing_type,
    ]
    matches = cdp.Tomogram.find(portal_client, query_filters)

    # Without a limit the query result is handed back untouched.
    if number_of_tomograms is None:
        return matches
    return matches[:number_of_tomograms]
1927
2028
@@ -30,7 +38,7 @@ def write_ome_zarr(output_file, segmentation, voxel_size):
3038
3139
3240def run_prediction (tomogram , deposition_id , processing_type ):
33- output_folder = os .path .join (f"upload_CZCDP-{ deposition_id } " , str (tomogram .run .dataset_id ))
41+ output_folder = os .path .join (OUTPUT_ROOT , f"upload_CZCDP-{ deposition_id } " , str (tomogram .run .dataset_id ))
3442 os .makedirs (output_folder , exist_ok = True )
3543
3644 output_file = os .path .join (output_folder , f"{ tomogram .run .name } .zarr" )
@@ -45,7 +53,8 @@ def run_prediction(tomogram, deposition_id, processing_type):
4553
4654 # Segment vesicles.
4755 model_path = "/mnt/lustre-emmy-hdd/projects/nim00007/models/synaptic-reconstruction/vesicle-DA-portal-v3"
48- segmentation = segment_vesicles (data , model_path = model_path )
56+ scale = (1.0 / 2.7 ,) * 3
57+ segmentation = segment_vesicles (data , model_path = model_path , scale = scale )
4958
5059 # Save the segmentation.
5160 write_ome_zarr (output_file , segmentation , voxel_size )
@@ -77,27 +86,60 @@ def check_result(tomogram, deposition_id, processing_type):
7786 napari .run ()
7887
7988
89+ def _get_task_tomograms (tomograms , slurm_tasks , task_id ):
90+ # TODO we could also filter already done tomos.
91+ inputs_to_tasks = np .array_split (tomograms , slurm_tasks )
92+ assert len (inputs_to_tasks ) == slurm_tasks
93+ return inputs_to_tasks [task_id ]
94+
95+
def process_slurm(args, tomograms, deposition_id, processing_type):
    """Run the predictions through a slurm array job.

    The behavior depends on the environment variable SLURM_ARRAY_TASK_ID:
    if it is not set, the array job is submitted with `sbatch` and the
    function returns. If it is set, the function runs inside of an array
    task and processes the tomograms assigned to that task.

    Args:
        args: The parsed command line arguments. Reads `check`, `slurm_tasks`
            and, if present, `number_of_tomograms`.
        tomograms: The tomograms to distribute over the array tasks.
        deposition_id: Passed on to `run_prediction`.
        processing_type: Passed on to `run_prediction`.

    Raises:
        ValueError: If `args.check` is set; checking is not supported here.
        subprocess.CalledProcessError: If `sbatch` exits with a non-zero status.
    """
    # Raise explicitly instead of asserting, asserts are removed under `python -O`.
    if args.check:
        raise ValueError("Checking the results is not supported when processing via slurm.")

    task_id = os.environ.get("SLURM_ARRAY_TASK_ID")

    if task_id is None:  # We are not in the slurm task and submit the job.
        # Assemble the command for submitting a slurm array job.
        # NOTE: the script path is relative, so this must be run from the folder containing it.
        script_path = "process_tomograms_on_the_fly.sbatch"
        cmd = ["sbatch", "-a", f"0-{args.slurm_tasks - 1}", script_path, "-s", str(args.slurm_tasks)]

        # Forward the tomogram limit, otherwise the tasks would split the full tomogram list.
        # NOTE(review): assumes the sbatch script passes its arguments on to this script - confirm.
        number_of_tomograms = getattr(args, "number_of_tomograms", None)
        if number_of_tomograms is not None:
            cmd.extend(["-n", str(number_of_tomograms)])

        print("Submitting to slurm:")
        print(cmd)
        # check=True so that a failed submission is reported instead of being ignored.
        subprocess.run(cmd, check=True)
        return

    # We are in the task.
    this_tomograms = _get_task_tomograms(tomograms, args.slurm_tasks, int(task_id))
    for tomogram in tqdm(this_tomograms, desc="Run prediction for tomograms on-the-fly"):
        run_prediction(tomogram, deposition_id, processing_type)
112+
113+
def process_local(args, tomograms, deposition_id, processing_type):
    """Handle all tomograms one after the other in the current process.

    Args:
        args: The parsed command line arguments. If `args.check` is set, the
            tomograms are passed to `check_result`, otherwise to `run_prediction`.
        tomograms: The tomograms to handle.
        deposition_id: Passed on to the per-tomogram function.
        processing_type: Passed on to the per-tomogram function.
    """
    print("Start processing", len(tomograms), "tomograms")

    # Decide once which function is applied to each tomogram.
    handle_tomogram = check_result if args.check else run_prediction

    progress = tqdm(tomograms, desc="Run prediction for tomograms on-the-fly")
    for current in progress:
        handle_tomogram(current, deposition_id, processing_type)
122+
123+
def main():
    """Parse the command line arguments and process the tomograms of the deposition.

    The tomograms are processed locally by default, or via a slurm array job
    if `-s/--slurm_tasks` is given. Invalid argument combinations are rejected
    by the argument parser, which prints a usage message and exits with status 2.
    """
    parser = argparse.ArgumentParser()
    # Whether to check the result with napari instead of running the prediction.
    parser.add_argument(
        "-c", "--check", action="store_true",
        help="Check the results with napari instead of running the prediction."
    )
    parser.add_argument(
        "-n", "--number_of_tomograms", type=int, default=None,
        help="Only process the first n tomograms. By default all tomograms are processed."
    )
    parser.add_argument(
        "-s", "--slurm_tasks", type=int, default=None,
        help="The number of slurm array tasks. By default the tomograms are processed locally."
    )
    args = parser.parse_args()

    # A negative limit would silently drop tomograms from the end of the list.
    if args.number_of_tomograms is not None and args.number_of_tomograms < 0:
        parser.error("--number_of_tomograms must not be negative.")
    if args.slurm_tasks is not None:
        # Zero or negative task counts would result in an invalid array range like '0--1'.
        if args.slurm_tasks < 1:
            parser.error("--slurm_tasks must be at least 1.")
        if args.check:
            parser.error("--check cannot be combined with --slurm_tasks.")

    deposition_id = 10313
    processing_type = "denoised"

    # Get all the (processed) tomogram ids in the deposition.
    tomograms = get_tomograms(deposition_id, processing_type, args.number_of_tomograms)

    if args.slurm_tasks is None:
        process_local(args, tomograms, deposition_id, processing_type)
    else:
        process_slurm(args, tomograms, deposition_id, processing_type)
99142
100143
101- # TODO segmented at wrong size, check voxel size!
# Only run the command line interface when this file is executed as a script.
if __name__ == "__main__":
    main()
0 commit comments