Skip to content

Commit 86f60c2

Browse files
Update file extraction functionality
1 parent 7027e74 commit 86f60c2

File tree

2 files changed

+108
-7
lines changed

2 files changed

+108
-7
lines changed

scripts/export_synapse_detections.py

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,27 @@ def export_synapse_detections(
2323
radius: float,
2424
id_offset: int,
2525
filter_ihc_components: List[int],
26+
position: str,
27+
halo: List[int],
28+
as_float: bool = False,
29+
use_syn_ids: bool = False,
2630
):
27-
"""Export synapse detections fro lower resolutions.
31+
"""Export synapse detections from S3.
2832
2933
Args:
3034
cochlea: Cochlea name on S3 bucket.
3135
scales: Scale for export of lower resolution.
32-
output_folder:
33-
synapse_name:
36+
output_folder: The output folder for saving the exported data.
37+
synapse_name: The name of the synapse detection source.
3438
reference_ihcs: Name of IHC segmentation.
3539
max_dist: Maximal distance of synapse to IHC segmentation.
36-
radius:
40+
radius: The radius for writing the synapse points to the output volume.
3741
id_offset: Offset of label id of synapse output to have different colours for visualization.
3842
filter_ihc_components: Component label(s) for filtering IHC segmentation.
43+
position: Optional position for extracting a crop from the data. Requires to also pass halo.
44+
halo: Halo for extracting a crop from the data.
45+
as_float: Whether to save the exported data as floating point values.
46+
use_syn_ids: Whether to write the synapse IDs or the matched IHC IDs to the output volume.
3947
"""
4048
s3 = create_s3_target()
4149

@@ -79,17 +87,37 @@ def export_synapse_detections(
7987
coordinates = np.round(coordinates, 0).astype("int")
8088

8189
ihc_ids = syn_table["matched_ihc"].values
90+
syn_ids = syn_table["spot_id"].values
91+
92+
if position is not None:
93+
assert halo is not None
94+
center = json.loads(position)
95+
assert len(halo) == len(center)
96+
center = [int(ce / (resolution * (2 ** scale))) for ce in center[::-1]]
97+
start = np.array([max(0, ce - ha) for ce, ha in zip(center, halo)])[None]
98+
stop = np.array([min(sh, ce + ha) for ce, ha, sh in zip(center, halo, shape)])[None]
99+
100+
mask = ((coordinates >= start) & (coordinates < stop)).all(axis=1)
101+
coordinates = coordinates[mask]
102+
coordinates -= start
103+
104+
ihc_ids = ihc_ids[mask]
105+
syn_ids = syn_ids[mask]
106+
107+
shape = tuple(int(sto - sta) for sta, sto in zip(start.squeeze(), stop.squeeze()))
82108

83109
# Create the output.
84110
output = np.zeros(shape, dtype="uint16")
85111
mask = ball(radius).astype(bool)
86112

87-
for coord, matched_ihc in tqdm(
88-
zip(coordinates, ihc_ids), total=len(coordinates), desc="Writing synapses to volume"
113+
ids = syn_ids if use_syn_ids else ihc_ids
114+
115+
for coord, syn_id in tqdm(
116+
zip(coordinates, ids), total=len(coordinates), desc="Writing synapses to volume"
89117
):
90118
bb = tuple(slice(c - radius, c + radius + 1) for c in coord)
91119
try:
92-
output[bb][mask] = matched_ihc + id_offset
120+
output[bb][mask] = syn_id + id_offset
93121
except IndexError:
94122
print("Index error for", coord)
95123
continue
@@ -101,6 +129,10 @@ def export_synapse_detections(
101129
out_path = os.path.join(out_folder, f"{synapse_name}_offset{id_offset}.tif")
102130
else:
103131
out_path = os.path.join(out_folder, f"{synapse_name}.tif")
132+
133+
if as_float:
134+
output = output.astype("float32")
135+
104136
print("Writing synapses to", out_path)
105137
tifffile.imwrite(out_path, output, bigtiff=True, compression="zlib")
106138

@@ -116,13 +148,19 @@ def main():
116148
parser.add_argument("--radius", type=int, default=3)
117149
parser.add_argument("--id_offset", type=int, default=0)
118150
parser.add_argument("--filter_ihc_components", nargs="+", type=int, default=[1])
151+
parser.add_argument("--position", default=None)
152+
parser.add_argument("--halo", default=None, nargs="+", type=int)
153+
parser.add_argument("--as_float", action="store_true")
154+
parser.add_argument("--use_syn_ids", action="store_true")
119155
args = parser.parse_args()
120156

121157
export_synapse_detections(
122158
args.cochlea, args.scale, args.output_folder,
123159
args.synapse_name, args.reference_ihcs,
124160
args.max_dist, args.radius,
125161
args.id_offset, args.filter_ihc_components,
162+
position=args.position, halo=args.halo,
163+
as_float=args.as_float, use_syn_ids=args.use_syn_ids,
126164
)
127165

128166

scripts/extract_block_from_s3.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import os
2+
import json
3+
4+
import numpy as np
5+
import pandas as pd
6+
import tifffile
7+
import zarr
8+
from flamingo_tools.s3_utils import get_s3_path, BUCKET_NAME, SERVICE_ENDPOINT
9+
10+
11+
def extract_block_from_s3(args):
    """Extract a block around a given center position from OME-Zarr sources on S3.

    The center position (``args.position``, a JSON list) is given in physical
    coordinates — presumably micrometer (0.38 per voxel at scale 0); confirm with caller.
    It is reversed from (x, y, z) to (z, y, x) and converted to voxel coordinates
    at the requested scale before cropping. Segmentation sources can optionally be
    restricted to the components in ``args.component_ids``, and the output can be
    cast to float32. Each extracted block is written as a compressed tif.
    """
    os.makedirs(args.output_folder, exist_ok=True)

    # Voxel size at scale level s: base resolution times 2**s (each scale halves the sampling).
    voxel_size = 0.38 * (2 ** args.scale)
    # Reverse the axis order and convert the physical center to voxel coordinates.
    center_voxels = [int(coord / voxel_size) for coord in json.loads(args.position)[::-1]]

    dataset_key = f"s{args.scale}"
    for source in args.sources:
        print("Extracting source:", source, "from", args.cochlea)
        internal_path = os.path.join(args.cochlea, "images", "ome-zarr", f"{source}.ome.zarr")
        s3_store, fs = get_s3_path(internal_path, bucket_name=BUCKET_NAME, service_endpoint=SERVICE_ENDPOINT)

        with zarr.open(s3_store, mode="r") as zarr_root:
            dataset = zarr_root[dataset_key]
            # Bounding box of size 2 * halo around the center, clipped to the dataset shape.
            bb = tuple(
                slice(max(0, ce - ha), min(sh, ce + ha))
                for ce, ha, sh in zip(center_voxels, args.halo, dataset.shape)
            )
            data = dataset[bb]

        if args.component_ids is not None:
            # Zero out all labels that do not belong to the requested components,
            # using the component assignment from the source's default table.
            table_path = os.path.join(BUCKET_NAME, args.cochlea, "tables", source, "default.tsv")
            with fs.open(table_path, "r") as table_file:
                table = pd.read_csv(table_file, sep="\t")
            keep_ids = table[table.component_labels.isin(args.component_ids)].label_id.values
            data[~np.isin(data, keep_ids)] = 0

        # Encode the (voxel) center position in the output filename.
        coord_string = "-".join([str(c).zfill(4) for c in center_voxels])
        out_path = os.path.join(args.output_folder, f"{args.cochlea}_{source}_scale{args.scale}_{coord_string}.tif")

        if args.as_float:
            data = data.astype("float32")
        tifffile.imwrite(out_path, data, compression="zlib")
43+
44+
45+
def main():
    """CLI entry point: parse the command line arguments and run the block extraction."""
    import argparse

    arg_parser = argparse.ArgumentParser()
    # Name of the cochlea dataset on the S3 bucket.
    arg_parser.add_argument("--cochlea", "-c", required=True)
    # Folder for writing the extracted tif files.
    arg_parser.add_argument("--output_folder", "-o", required=True)
    # The image / segmentation sources to extract blocks from.
    arg_parser.add_argument("--sources", "-s", required=True, nargs="+")
    # Center position of the block as a JSON list, in physical coordinates
    # (it is divided by the resolution before use).
    arg_parser.add_argument("--position", "-p", required=True)
    # Half-width of the extracted block per axis, in voxels at the chosen scale.
    arg_parser.add_argument("--halo", nargs="+", type=int, default=[32, 128, 128])
    # Scale level of the image pyramid to extract from.
    arg_parser.add_argument("--scale", type=int, default=0)
    # Whether to cast the output to float32 before writing.
    arg_parser.add_argument("--as_float", action="store_true")
    # Optional component ids for filtering a segmentation source.
    arg_parser.add_argument("--component_ids", type=int, nargs="+")

    extract_block_from_s3(arg_parser.parse_args())


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)