
Commit d23d544

Add code for tomogram on-the-fly processing (#117)
Add support for processing ome.zarr files on the fly
1 parent 964d6c3 commit d23d544

File tree

3 files changed: +165, -2 lines
process_tomograms_on_the_fly.py

Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
import argparse
import os
import subprocess

import cryoet_data_portal as cdp
import numpy as np
import zarr

from ome_zarr.io import parse_url
from ome_zarr.writer import write_image
from synapse_net.file_utils import read_data_from_cryo_et_portal_run
from synapse_net.inference.vesicles import segment_vesicles
from tqdm import tqdm

# OUTPUT_ROOT = ""
OUTPUT_ROOT = "/mnt/vast-nhr/projects/nim00007/data/synaptic-reconstruction/portal"


def get_tomograms(deposition_id, processing_type, number_of_tomograms=None):
    client = cdp.Client()
    tomograms = cdp.Tomogram.find(
        client, [cdp.Tomogram.deposition_id == deposition_id, cdp.Tomogram.processing == processing_type]
    )
    if number_of_tomograms is not None:
        tomograms = tomograms[:number_of_tomograms]
    return tomograms


def write_ome_zarr(output_file, segmentation, voxel_size, unit="nanometer"):
    store = parse_url(output_file, mode="w").store
    root = zarr.group(store=store)

    scale = list(voxel_size.values())
    trafo = [
        [{"scale": scale, "type": "scale"}]
    ]
    axes = [
        {"name": "z", "type": "space", "unit": unit},
        {"name": "y", "type": "space", "unit": unit},
        {"name": "x", "type": "space", "unit": unit},
    ]
    write_image(segmentation, root, axes=axes, coordinate_transformations=trafo, scaler=None)


def run_prediction(tomogram, deposition_id, processing_type):
    output_folder = os.path.join(OUTPUT_ROOT, f"upload_CZCDP-{deposition_id}", str(tomogram.run.dataset_id))
    os.makedirs(output_folder, exist_ok=True)

    output_file = os.path.join(output_folder, f"{tomogram.run.name}.zarr")
    # We don't need to do anything if this file is already processed.
    if os.path.exists(output_file):
        return

    # Read tomogram data on the fly.
    data, voxel_size = read_data_from_cryo_et_portal_run(
        tomogram.run_id, processing_type=processing_type
    )

    # Segment vesicles.
    model_path = "/mnt/lustre-emmy-hdd/projects/nim00007/models/synaptic-reconstruction/vesicle-DA-portal-v3"
    scale = (1.0 / 2.7,) * 3
    segmentation = segment_vesicles(data, model_path=model_path, scale=scale)

    # Save the segmentation.
    write_ome_zarr(output_file, segmentation, voxel_size)


# TODO download on lower scale
def check_result(tomogram, deposition_id, processing_type):
    import napari

    # Read tomogram data on the fly.
    print("Download data ...")
    data, voxel_size = read_data_from_cryo_et_portal_run(
        tomogram.run_id, processing_type=processing_type
    )

    # Read the output file if it exists.
    output_folder = os.path.join(f"upload_CZCDP-{deposition_id}", str(tomogram.run.dataset_id))
    output_file = os.path.join(output_folder, f"{tomogram.run.name}.zarr")
    if os.path.exists(output_file):
        with zarr.open(output_file, "r") as f:
            segmentation = f["0"][:]
    else:
        segmentation = None

    v = napari.Viewer()
    v.add_image(data)
    if segmentation is not None:
        v.add_labels(segmentation)
    napari.run()


def _get_task_tomograms(tomograms, slurm_tasks, task_id):
    # TODO we could also filter already done tomos.
    inputs_to_tasks = np.array_split(tomograms, slurm_tasks)
    assert len(inputs_to_tasks) == slurm_tasks
    return inputs_to_tasks[task_id]


def process_slurm(args, tomograms, deposition_id, processing_type):
    assert not args.check
    task_id = os.environ.get("SLURM_ARRAY_TASK_ID")

    if task_id is None:  # We are not in the slurm task and submit the job.
        # Assemble the command for submitting a slurm array job.
        script_path = "process_tomograms_on_the_fly.sbatch"
        cmd = ["sbatch", "-a", f"0-{args.slurm_tasks-1}", script_path, "-s", str(args.slurm_tasks)]
        print("Submitting to slurm:")
        print(cmd)
        subprocess.run(cmd)
    else:  # We are in the task.
        task_id = int(task_id)
        this_tomograms = _get_task_tomograms(tomograms, args.slurm_tasks, task_id)
        for tomogram in tqdm(this_tomograms, desc="Run prediction for tomograms on-the-fly"):
            run_prediction(tomogram, deposition_id, processing_type)


def process_local(args, tomograms, deposition_id, processing_type):
    # Process each tomogram.
    print("Start processing", len(tomograms), "tomograms")
    for tomogram in tqdm(tomograms, desc="Run prediction for tomograms on-the-fly"):
        if args.check:
            check_result(tomogram, deposition_id, processing_type)
        else:
            run_prediction(tomogram, deposition_id, processing_type)


def main():
    parser = argparse.ArgumentParser()
    # Whether to check the result with napari instead of running the prediction.
    parser.add_argument("-c", "--check", action="store_true")
    parser.add_argument("-n", "--number_of_tomograms", type=int, default=None)
    parser.add_argument("-s", "--slurm_tasks", type=int, default=None)
    args = parser.parse_args()

    deposition_id = 10313
    processing_type = "denoised"

    # Get all the (processed) tomogram ids in the deposition.
    tomograms = get_tomograms(deposition_id, processing_type, args.number_of_tomograms)

    if args.slurm_tasks is None:
        process_local(args, tomograms, deposition_id, processing_type)
    else:
        process_slurm(args, tomograms, deposition_id, processing_type)


if __name__ == "__main__":
    main()
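The CLI flags map directly onto the functions above: `-n` limits how many tomograms are fetched, `-c` opens the results in napari instead of predicting, and `-s` fans the work out as a slurm array. A minimal in-process sketch of the same workflow, reusing the functions defined above; the deposition id and processing type mirror the hard-coded defaults in `main`, and the two-tomogram limit is arbitrary:

    # Hypothetical smoke test, reusing get_tomograms / run_prediction from this file.
    deposition_id, processing_type = 10313, "denoised"
    tomograms = get_tomograms(deposition_id, processing_type, number_of_tomograms=2)
    for tomogram in tomograms:
        run_prediction(tomogram, deposition_id, processing_type)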
process_tomograms_on_the_fly.sbatch

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
#!/bin/bash
#SBATCH --partition=grete:shared
#SBATCH -G A100:1
#SBATCH --time=00-08:00:00
#SBATCH --nodes=1
#SBATCH -c 12

/scratch-grete/usr/nimcpape/software/mamba/envs/sam/bin/python process_tomograms_on_the_fly.py $@
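Note the round trip: when `process_slurm` runs without `SLURM_ARRAY_TASK_ID` set, it submits this batch script, which forwards its arguments via `$@` back to `process_tomograms_on_the_fly.py`; each array task then re-enters `process_slurm` with the task id set and takes its slice via `_get_task_tomograms`. A minimal illustration of the `np.array_split` partitioning that slice relies on (the counts here are made up):

    import numpy as np

    # 10 tomograms across 4 array tasks: chunk sizes differ by at most one.
    chunks = np.array_split(list(range(10)), 4)
    print([len(c) for c in chunks])  # [3, 3, 2, 2]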

synapse_net/file_utils.py

Lines changed: 7 additions & 2 deletions
@@ -142,11 +142,16 @@ def parse_s3_uri(uri):
     # Read the axis and transformation metadata for this dataset, to determine the voxel size.
     axes = [axis["name"] for axis in multiscales["axes"]]
     assert set(axes) == set("xyz")
+    units = [axis.get("unit", "angstrom") for axis in multiscales["axes"]]
+    assert all(unit in ("angstrom", "nanometer") for unit in units)
+
     transformations = multiscales["datasets"][scale_level]["coordinateTransformations"]
     scale_transformation = [trafo["scale"] for trafo in transformations if trafo["type"] == "scale"][0]

-    # The voxel size is given in angstrom, we divide it by 10 to convert it to nanometer.
-    voxel_size = {axis: scale / 10.0 for axis, scale in zip(axes, scale_transformation)}
+    # Convert the given unit size to nanometer.
+    # (It is typically given in angstrom, and we have to divide by a factor of 10.)
+    unit_factor = [10.0 if unit == "angstrom" else 1.0 for unit in units]
+    voxel_size = {axis: scale / factor for axis, scale, factor in zip(axes, scale_transformation, unit_factor)}

     # Get the internal path for the given scale and load the data.
     internal_path = multiscales["datasets"][scale_level]["path"]
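A quick sketch of what the new unit handling computes, with made-up metadata values (the 13.48 Å voxel size is hypothetical, not from the commit):

    axes = ["z", "y", "x"]
    units = ["angstrom", "angstrom", "angstrom"]
    scale_transformation = [13.48, 13.48, 13.48]

    # Angstrom values are divided by 10; nanometer values pass through unchanged.
    unit_factor = [10.0 if unit == "angstrom" else 1.0 for unit in units]
    voxel_size = {axis: scale / factor for axis, scale, factor in zip(axes, scale_transformation, unit_factor)}
    print(voxel_size)  # {'z': 1.348, 'y': 1.348, 'x': 1.348}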
