Skip to content

Commit caba84f

Browse files
Extend support for data download from CryoET portal
1 parent 2d7cf7f commit caba84f

File tree

2 files changed

+121
-11
lines changed

2 files changed

+121
-11
lines changed
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import json
2+
import os
3+
4+
from synapse_net.file_utils import read_data_from_cryo_et_portal_run
5+
from tqdm import tqdm
6+
7+
8+
def download_tomogram_list(run_ids, output_root):
    """Download one mrc tomogram per id in run_ids into output_root.

    Each tomogram is fetched from the CryoET data portal and stored as
    ``<output_root>/<run_id>.mrc``. Ids for which no tomogram is found
    are reported but do not abort the download of the remaining ids.
    """
    print("Downloading", len(run_ids), "tomograms")
    os.makedirs(output_root, exist_ok=True)
    for this_id in tqdm(run_ids):
        out_path = os.path.join(output_root, f"{this_id}.mrc")
        data, _ = read_data_from_cryo_et_portal_run(
            this_id, use_zarr_format=False, output_path=out_path, id_field="id",
        )
        if data is None:
            print("Did not find a tomogram for", this_id)
18+
19+
20+
def download_tomograms_for_da():
    """Download the tomograms used for domain adaptation.

    The run ids are read from ./list_for_da.json.
    """
    output_root = "/scratch-grete/projects/nim00007/cryo-et/from_portal/for_domain_adaptation"
    with open("./list_for_da.json") as f:
        ids = json.load(f)
    download_tomogram_list(ids, output_root)
25+
26+
27+
def download_tomograms_for_eval():
    """Download the tomograms used for evaluation.

    The run ids are read from ./list_for_eval.json.
    """
    with open("./list_for_eval.json") as f:
        run_ids = json.load(f)
    # FIX: download_tomogram_list requires an output root; the original call
    # omitted it and would raise a TypeError. The path below mirrors the
    # domain-adaptation target directory.
    # NOTE(review): confirm this is the intended evaluation output location.
    output_root = "/scratch-grete/projects/nim00007/cryo-et/from_portal/for_evaluation"
    download_tomogram_list(run_ids, output_root)
31+
32+
33+
def main():
    """Script entry point: download the domain-adaptation tomograms."""
    # download_tomograms_for_eval()
    download_tomograms_for_da()


if __name__ == "__main__":
    main()

synapse_net/file_utils.py

Lines changed: 82 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@
1515
except ImportError:
1616
zarr = None
1717

18+
try:
19+
import s3fs
20+
except ImportError:
21+
s3fs = None
22+
1823

1924
def get_cache_dir() -> str:
2025
"""Get the cache directory of synapse net.
@@ -100,12 +105,13 @@ def read_mrc(path: str) -> Tuple[np.ndarray, Dict[str, float]]:
100105
return data, voxel_size
101106

102107

103-
def read_ome_zarr(uri: str, scale_level: int = 0, fs=None) -> Tuple[np.ndarray, Dict[str, float]]:
    """Read data and voxel size from an ome.zarr file.

    Args:
        uri: Path or url to the ome.zarr file.
        scale_level: The level of the multi-scale image pyramid to load.
        fs: S3 filesystem to use for initializing the store.

    Returns:
        The data read from the file.
        The voxel size read from the file.
    """
    if zarr is None:
        raise RuntimeError("The zarr library is required to read ome.zarr files.")

    def parse_s3_uri(uri):
        # FIX: str.lstrip strips *characters* from the given set, not a prefix,
        # so lstrip("s3://") would also eat leading 's', '3', ':' or '/' from
        # the bucket name. Remove the scheme prefix explicitly instead.
        return uri[len("s3://"):] if uri.startswith("s3://") else uri

    # FIX: match the full scheme prefix so local paths that merely start with
    # "s3" (e.g. "s3_data/") are not mistaken for remote URIs.
    if uri.startswith("s3://"):
        if s3fs is None:
            raise RuntimeError("The s3fs library is required to read ome.zarr files from S3.")
        if fs is None:
            fs = s3fs.S3FileSystem(anon=True)
        store = s3fs.S3Map(root=parse_s3_uri(uri), s3=fs, check=False)
    elif fs is not None:
        # An explicit filesystem was passed: treat the uri as a path on it.
        store = s3fs.S3Map(root=parse_s3_uri(uri), s3=fs, check=False)
    else:
        if not os.path.exists(uri):
            raise ValueError(f"Cannot find the filepath at {uri}.")
        store = uri

    with zarr.open(store, "r") as f:
        multiscales = f.attrs["multiscales"][0]

        # Read the axis and transformation metadata for this dataset, to determine the voxel size.
        axes = [axis["name"] for axis in multiscales["axes"]]
        # FIX: validate with an explicit exception instead of assert, which is
        # stripped when python runs with -O.
        if set(axes) != set("xyz"):
            raise ValueError(f"Expected the axes x, y, z, got: {axes}")
        transformations = multiscales["datasets"][scale_level]["coordinateTransformations"]
        scale_transformation = [trafo["scale"] for trafo in transformations if trafo["type"] == "scale"][0]

        # The voxel size is given in angstrom, we divide it by 10 to convert it to nanometer.
        voxel_size = {axis: scale / 10.0 for axis, scale in zip(axes, scale_transformation)}

        # Get the internal path for the given scale and load the data.
        internal_path = multiscales["datasets"][scale_level]["path"]
        data = f[internal_path][:]

    return data, voxel_size
128156

129157

130158
def read_data_from_cryo_et_portal_run(
    run_id: int,
    output_path: Optional[str] = None,
    use_zarr_format: bool = True,
    processing_type: str = "denoised",
    id_field: str = "run_id",
    scale_level: Optional[int] = None,
) -> Tuple[np.ndarray, Dict[str, float]]:
    """Read data and voxel size from a CryoET Data Portal run.

    Args:
        run_id: The ID of the experiment run.
        output_path: The path for saving the data. The data will be streamed if the path is not given.
        use_zarr_format: Whether to use the data in zarr format instead of mrc.
        processing_type: The processing type of the tomogram to download.
        id_field: The name of the id field.
        scale_level: The scale level to read from the data. Only valid for zarr data.

    Returns:
        The data read from the run.
        The voxel size read from the run.
    """
    # Re-use a previous download if it is already present at the output path.
    if output_path is not None and os.path.exists(output_path):
        if use_zarr_format:
            return read_ome_zarr(output_path, scale_level=0 if scale_level is None else scale_level)
        return read_mrc(output_path)

    if cdp is None:
        raise RuntimeError("The CryoET data portal library is required to download data from the portal.")
    if s3fs is None:
        # FIX: the original message ("requires s3fs download") was garbled.
        raise RuntimeError("The CryoET data portal download requires the s3fs library.")

    client = cdp.Client()

    # The portal data is public, so anonymous S3 access is sufficient.
    fs = s3fs.S3FileSystem(anon=True)
    tomograms = cdp.Tomogram.find(
        client, [getattr(cdp.Tomogram, id_field) == run_id, cdp.Tomogram.processing == processing_type]
    )
    if len(tomograms) == 0:
        return None, None
    if len(tomograms) > 1:
        raise NotImplementedError(
            f"Expected a single tomogram for {id_field}={run_id}, got {len(tomograms)}."
        )
    tomo = tomograms[0]

    if use_zarr_format:
        if output_path is None:
            # FIX: the requested scale level was previously computed but not
            # forwarded to read_ome_zarr, so scale level 0 was always read.
            scale_level = 0 if scale_level is None else scale_level
            data, voxel_size = read_ome_zarr(tomo.s3_omezarr_dir, scale_level=scale_level, fs=fs)
        else:
            # TODO: write the output to ome zarr, for all scale levels.
            raise NotImplementedError("Downloading the zarr data to a local path is not yet supported.")
    else:
        if scale_level is not None:
            raise ValueError("The scale_level argument is only supported for zarr data.")
        if output_path is None:
            raise RuntimeError("You have to pass an output_path to download the data as mrc file.")
        fs.get(tomo.s3_mrc_file, output_path)
        data, voxel_size = read_mrc(output_path)

    return data, voxel_size

0 commit comments

Comments
 (0)