
Commit 06e6740
Initial adaptation to work with S3 data
1 parent f12f9d3

6 files changed: +280 additions, -56 deletions
Lines changed: 73 additions & 9 deletions
@@ -1,16 +1,55 @@
 import numpy as np
 import vigra
+import multiprocessing as mp
+from concurrent import futures
 
 from skimage import measure
 from scipy.spatial import distance
 from scipy.sparse import csr_matrix
+from tqdm import tqdm
 
+import elf.parallel as parallel
+from elf.io import open_file
+import nifty.tools as nt
 
-def filter_isolated_objects(segmentation, distance_threshold=15, neighbor_threshold=5):
-    segmentation, n_ids, _ = vigra.analysis.relabelConsecutive(segmentation, start_label=1, keep_zeros=True)
+def filter_isolated_objects(
+    segmentation, output_path, tsv_table=None,
+    distance_threshold=15, neighbor_threshold=5, min_size=1000,
+    output_key="segmentation_postprocessed",
+):
+    """
+    Postprocessing step to filter isolated objects from a segmentation.
+    Instances are filtered out if they have fewer neighbors than a given threshold within a given distance around them.
+    Additionally, size filtering is possible if a TSV table is supplied.
 
-    props = measure.regionprops(segmentation)
-    coordinates = np.array([prop.centroid for prop in props])
+    :param dataset segmentation: Dataset containing the segmentation
+    :param str output_path: Output path for the postprocessed segmentation
+    :param tsv_table: Optional table, read from a TSV file in MoBIE format, containing segmentation parameters
+    :param int distance_threshold: Distance in micrometers to check for neighbors
+    :param int neighbor_threshold: Minimal number of neighbors for filtering
+    :param int min_size: Minimal number of pixels for filtering small instances
+    :param str output_key: Output key for the postprocessed segmentation
+    """
+    if tsv_table is not None:
+        n_pixels = tsv_table["n_pixels"].to_list()
+        label_ids = tsv_table["label_id"].to_list()
+        centroids = list(zip(tsv_table["anchor_x"], tsv_table["anchor_y"], tsv_table["anchor_z"]))
+        n_ids = len(label_ids)
+
+        # filter out cells smaller than min_size
+        if min_size is not None:
+            min_size_label_ids = [l for (l, n) in zip(label_ids, n_pixels) if n <= min_size]
+            centroids = [c for (c, l) in zip(centroids, label_ids) if l not in min_size_label_ids]
+            label_ids = [int(l) for l in label_ids if l not in min_size_label_ids]
+
+        coordinates = np.array(centroids)
+        label_ids = np.array(label_ids)
+
+    else:
+        segmentation, n_ids, _ = vigra.analysis.relabelConsecutive(segmentation[:], start_label=1, keep_zeros=True)
+        props = measure.regionprops(segmentation)
+        coordinates = np.array([prop.centroid for prop in props])
+        label_ids = np.unique(segmentation)[1:]
 
     # Calculate pairwise distances and convert to a square matrix
     dist_matrix = distance.pdist(coordinates)
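The lines between this hunk and the next, which turn the pairwise distances into a sparse neighbor matrix, are unchanged and therefore not shown in the diff. A toy sketch of that neighbor-counting step; the thresholding into a csr_matrix is assumed from the surrounding code, not quoted from the commit:

```python
import numpy as np
from scipy.spatial import distance
from scipy.sparse import csr_matrix

# Three mutually close centroids and one isolated outlier.
coordinates = np.array([[0, 0, 0], [1, 0, 0], [0, 2, 0], [50, 50, 50]], dtype=float)

# Pairwise distances as a square matrix, as in the hunk above.
dist_matrix = distance.squareform(distance.pdist(coordinates))

# Assumed thresholding step: objects closer than the distance threshold count as neighbors.
adjacency = csr_matrix((dist_matrix < 15) & (dist_matrix > 0))
neighbor_counts = adjacency.sum(axis=1)

filter_mask = np.array(neighbor_counts < 2).squeeze()
print(filter_mask)  # [False False False  True] -> only the outlier would be filtered
```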
@@ -22,13 +61,38 @@ def filter_isolated_objects(segmentation, distance_threshold=15, neighbor_threshold=5):
     # Sum each row to count neighbors
     neighbor_counts = sparse_matrix.sum(axis=1)
 
-    seg_ids = np.unique(segmentation)[1:]
     filter_mask = np.array(neighbor_counts < neighbor_threshold).squeeze()
-    filter_ids = seg_ids[filter_mask]
+    filter_ids = label_ids[filter_mask]
+
+    shape = segmentation.shape
+    block_shape = (128, 128, 128)
+    chunks = (128, 128, 128)
+
+    blocking = nt.blocking([0] * len(shape), shape, block_shape)
+
+    output = open_file(output_path, mode="a")
+
+    output_dataset = output.create_dataset(
+        output_key, shape=shape, dtype=segmentation.dtype,
+        chunks=chunks, compression="gzip"
+    )
+
+    def filter_chunk(block_id):
+        """
+        Set all points within a chunk to zero if they match filter IDs.
+        """
+        block = blocking.getBlock(block_id)
+        volume_index = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
+        data = segmentation[volume_index]
+        data[np.isin(data, filter_ids)] = 0
+        output_dataset[volume_index] = data
+
+    # Limit the number of cores for parallelization.
+    n_threads = min(16, mp.cpu_count())
 
-    seg_filtered = segmentation.copy()
-    seg_filtered[np.isin(seg_filtered, filter_ids)] = 0
+    with futures.ThreadPoolExecutor(n_threads) as filter_pool:
+        list(tqdm(filter_pool.map(filter_chunk, range(blocking.numberOfBlocks)), total=blocking.numberOfBlocks))
 
-    seg_filtered, n_ids_filtered, _ = vigra.analysis.relabelConsecutive(seg_filtered, start_label=1, keep_zeros=True)
+    seg_filtered, n_ids_filtered, _ = parallel.relabel_consecutive(output_dataset, start_label=1, keep_zeros=True, block_shape=(128, 128, 128))
 
     return seg_filtered, n_ids, n_ids_filtered
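With the new signature, the filtered result is written block-wise into output_path instead of being held purely in memory. A hypothetical call, assuming filter_isolated_objects is in scope and the table is a MoBIE default table with label_id, n_pixels, and anchor_x/y/z columns; the paths and dataset keys are placeholders:

```python
import pandas as pd
from elf.io import open_file

# Placeholder paths; the segmentation is expected as a chunked dataset (zarr/n5).
segmentation = open_file("segmentation.zarr", mode="r")["segmentation"]
table = pd.read_csv("tables/default.tsv", sep="\t")

seg_filtered, n_ids, n_ids_filtered = filter_isolated_objects(
    segmentation, output_path="segmentation_postprocessed.zarr", tsv_table=table,
    distance_threshold=15, neighbor_threshold=5, min_size=1000,
)
print(f"Kept {n_ids_filtered} of {n_ids} objects.")
```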

flamingo_tools/segmentation/unet_prediction.py

Lines changed: 60 additions & 8 deletions
@@ -1,5 +1,6 @@
 import multiprocessing as mp
 import os
+import sys
 import warnings
 from concurrent import futures
 
@@ -10,6 +11,7 @@
 import vigra
 import torch
 import z5py
+import zarr
 import json
 
 from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper
@@ -18,6 +20,10 @@
 from torch_em.util import load_model
 from torch_em.util.prediction import predict_with_halo
 from tqdm import tqdm
+from inspect import getsourcefile
+
+sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(getsourcefile(lambda: 0)))), "scripts", "prediction"))
+import upload_to_s3
 
 """
 Prediction using distance U-Net.
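The new upload_to_s3 import relies on a sys.path manipulation rather than a package import. As a rough orientation for what the path expression resolves to; the repository layout is inferred from the file paths in this commit, not stated in the code:

```python
import os
from inspect import getsourcefile

# In unet_prediction.py, getsourcefile(lambda: 0) returns the path of the current
# source file, e.g. <repo>/flamingo_tools/segmentation/unet_prediction.py.
# Three dirname() calls walk up to <repo>, to which "scripts/prediction"
# (the assumed location of upload_to_s3.py) is appended.
this_file = getsourcefile(lambda: 0)
repo_root = os.path.dirname(os.path.dirname(os.path.dirname(this_file)))
print(os.path.join(repo_root, "scripts", "prediction"))
```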
@@ -43,7 +49,7 @@ def ndim(self):
         return self._volume.ndim - 1
 
 
-def prediction_impl(input_path, input_key, output_folder, model_path, scale, block_shape, halo, prediction_instances=1, slurm_task_id=0, mean=None, std=None):
+def prediction_impl(input_path, input_key, output_folder, model_path, scale, block_shape, halo, prediction_instances=1, slurm_task_id=0, mean=None, std=None, s3=None):
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
         if os.path.isdir(model_path):
@@ -56,6 +62,9 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, block_shape, halo, prediction_instances=1, slurm_task_id=0, mean=None, std=None, s3=None):
 
     if input_key is None:
         input_ = imageio.imread(input_path)
+    elif s3 is not None:
+        with zarr.open(input_path, mode="r") as f:
+            input_ = f[input_key]
     else:
         input_ = open_file(input_path, "r")[input_key]
 
@@ -138,7 +147,7 @@ def postprocess(x):
     return original_shape
 
 
-def find_mask(input_path, input_key, output_folder):
+def find_mask(input_path, input_key, output_folder, s3=None):
     mask_path = os.path.join(output_folder, "mask.zarr")
     f = z5py.File(mask_path, "a")
 
@@ -149,6 +158,10 @@ def find_mask(input_path, input_key, output_folder, s3=None):
     if input_key is None:
         raw = imageio.imread(input_path)
         chunks = (64, 64, 64)
+    elif s3 is not None:
+        with zarr.open(input_path, mode="r") as fin:
+            raw = fin[input_key]
+            chunks = raw.chunks
     else:
         fin = open_file(input_path, "r")
         raw = fin[input_key]
@@ -243,7 +256,10 @@ def write_block(block_id):
         tp.map(write_block, range(blocking.numberOfBlocks))
 
 
-def calc_mean_and_std(input_path, input_key, output_folder):
+def calc_mean_and_std(
+    input_path, input_key, output_folder,
+    s3=None,
+):
     """
     Calculate mean and standard deviation of full volume.
     Parameters are saved in 'mean_std.json' within the output folder.
@@ -254,6 +270,9 @@ def calc_mean_and_std(
 
     if input_key is None:
         input_ = imageio.imread(input_path)
+    elif s3 is not None:
+        with zarr.open(input_path, mode="r") as f:
+            input_ = f[input_key]
     else:
         input_ = open_file(input_path, "r")[input_key]
 
@@ -267,6 +286,7 @@ def calc_mean_and_std(
     with open(json_file, "w") as f:
         json.dump(ddict, f)
 
+
 def run_unet_prediction(
     input_path, input_key,
     output_folder, model_path,
@@ -288,32 +308,63 @@ def run_unet_prediction(
 
 def run_unet_prediction_preprocess_slurm(
     input_path, input_key, output_folder,
+    s3=None, s3_bucket_name=None, s3_service_endpoint=None, s3_credentials=None,
 ):
     """
     Pre-processing for the parallel prediction with U-Net models.
     Masks are stored in mask.zarr in the output folder.
     The mean and standard deviation are precomputed for later usage during prediction
-    and stored in a JSON file within the output folder as mean_std.json
+    and stored in a JSON file within the output folder as mean_std.json.
     """
-    find_mask(input_path, input_key, output_folder)
-    calc_mean_and_std(input_path, input_key, output_folder)
+    if s3 is not None:
+        bucket_name, service_endpoint, credentials = upload_to_s3.check_s3_credentials(s3_bucket_name, s3_service_endpoint, s3_credentials)
+
+        input_path, fs = upload_to_s3.get_s3_path(input_path, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials)
+
+    if not os.path.isdir(os.path.join(output_folder, "mask.zarr")):
+        find_mask(input_path, input_key, output_folder, s3=s3)
+
+    calc_mean_and_std(input_path, input_key, output_folder, s3=s3)
+
 
 def run_unet_prediction_slurm(
     input_path, input_key, output_folder, model_path,
     scale=None,
     block_shape=None, halo=None, prediction_instances=1,
+    s3=None, s3_bucket_name=None, s3_service_endpoint=None, s3_credentials=None,
 ):
+    """
+    Run prediction of distance U-Net for data stored locally or on an S3 bucket.
+
+    :param str input_path: File path to input data
+    :param str input_key: Input key for data in ome.zarr format
+    :param str output_folder: Output folder for prediction.zarr
+    :param str model_path: File path to distance U-Net model
+    :param float scale:
+    :param tuple block_shape:
+    :param tuple halo:
+    :param int prediction_instances: Number of workers for parallel computation within slurm array
+    :param bool s3: Flag for accessing data on S3 bucket
+    :param str s3_bucket_name: S3 bucket name. Optional if BUCKET_NAME has been exported
+    :param str s3_service_endpoint: S3 service endpoint. Optional if SERVICE_ENDPOINT has been exported
+    :param str s3_credentials: Path to file containing S3 credentials
+    """
     os.makedirs(output_folder, exist_ok=True)
     prediction_instances = int(prediction_instances)
     slurm_task_id = os.environ.get("SLURM_ARRAY_TASK_ID")
 
+    if s3 is not None:
+        bucket_name, service_endpoint, credentials = upload_to_s3.check_s3_credentials(s3_bucket_name, s3_service_endpoint, s3_credentials)
+
+        input_path, fs = upload_to_s3.get_s3_path(input_path, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials)
+
     if slurm_task_id is not None:
         slurm_task_id = int(slurm_task_id)
     else:
         raise ValueError("The SLURM_ARRAY_TASK_ID is not set. Ensure that you are using the '-a' option with SBATCH.")
 
     if not os.path.isdir(os.path.join(output_folder, "mask.zarr")):
-        find_mask(input_path, input_key, output_folder)
+        find_mask(input_path, input_key, output_folder, s3=s3)
 
     # get pre-computed mean and standard deviation of full volume from JSON file
     if os.path.isfile(os.path.join(output_folder, "mean_std.json")):
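For context, the mean/std read-back that follows this check is not part of the diff. A sketch of what it presumably does; the key names in mean_std.json are an assumption, since the dictionary written by calc_mean_and_std is not shown here:

```python
import json
import os

output_folder = "/scratch/prediction_out"  # placeholder

# Assumed file layout: {"mean": <float>, "std": <float>} written by calc_mean_and_std.
json_file = os.path.join(output_folder, "mean_std.json")
if os.path.isfile(json_file):
    with open(json_file, "r") as f:
        stats = json.load(f)
    mean, std = stats["mean"], stats["std"]
else:
    mean, std = None, None  # prediction_impl would then compute them itself (assumption)
```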
@@ -328,9 +379,10 @@ def run_unet_prediction_slurm(
     original_shape = prediction_impl(
         input_path, input_key, output_folder, model_path, scale, block_shape, halo,
         prediction_instances=prediction_instances, slurm_task_id=slurm_task_id,
-        mean=mean, std=std,
+        mean=mean, std=std, s3=s3,
     )
 
+
 # does NOT need GPU, FIXME: only run on CPU
 def run_unet_segmentation_slurm(output_folder, min_size):
     min_size = int(min_size)
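Taken together, a hedged sketch of how the S3-aware entry points might be driven. All paths, the bucket, and the credentials file are placeholders, and whether the module is importable as flamingo_tools.segmentation.unet_prediction depends on how the package is installed. Per the docstrings above, SLURM_ARRAY_TASK_ID must be set for the prediction step, and BUCKET_NAME/SERVICE_ENDPOINT can be exported instead of passing the corresponding arguments:

```python
from flamingo_tools.segmentation.unet_prediction import (
    run_unet_prediction_preprocess_slurm,
    run_unet_prediction_slurm,
)

# Step 1: compute mask.zarr and mean_std.json once (no GPU needed).
run_unet_prediction_preprocess_slurm(
    "MyCochlea/raw.ome.zarr", "s0", "/scratch/prediction_out",
    s3=True, s3_credentials="/home/user/.s3_credentials",
)

# Step 2: run the actual prediction from within a SLURM array job.
run_unet_prediction_slurm(
    "MyCochlea/raw.ome.zarr", "s0", "/scratch/prediction_out", "/models/distance_unet",
    block_shape=(64, 256, 256), halo=(16, 64, 64), prediction_instances=4,
    s3=True, s3_credentials="/home/user/.s3_credentials",
)
```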

scripts/extract_block.py

Lines changed: 25 additions & 23 deletions
@@ -1,10 +1,14 @@
 import os
+import sys
 import argparse
 import numpy as np
 import z5py
 import zarr
 
-import s3fs
+from inspect import getsourcefile
+
+sys.path.append(os.path.join(os.path.dirname(getsourcefile(lambda: 0)), "prediction"))
+import upload_to_s3
 
 """
 This script extracts data around an input center coordinate in a given ROI halo.
@@ -18,7 +22,10 @@
 """
 
 
-def main(input_file, output_dir, input_key, resolution, coords, roi_halo, s3):
+def main(
+    input_file, output_dir, coords, input_key, resolution, roi_halo,
+    s3, s3_credentials, s3_bucket_name, s3_service_endpoint,
+):
     """
 
     :param str input_file: File path to input folder in n5 format
@@ -28,6 +35,9 @@ def main(input_file, output_dir, input_key, resolution, coords, roi_halo, s3):
     :param str coords: Center coordinates of extracted 3D volume in format 'x,y,z'
     :param str roi_halo: ROI halo of extracted 3D volume in format 'x,y,z'
     :param bool s3: Flag for using an S3 bucket
+    :param str s3_credentials: Path to file containing S3 credentials
+    :param str s3_bucket_name: S3 bucket name. Optional if BUCKET_NAME has been exported
+    :param str s3_service_endpoint: S3 service endpoint. Optional if SERVICE_ENDPOINT has been exported
     """
 
     coords = [int(r) for r in coords.split(",")]
@@ -61,33 +71,18 @@ def main(input_file, output_dir, input_key, resolution, coords, roi_halo, s3):
     roi = tuple(slice(co - rh, co + rh) for co, rh in zip(coords, roi_halo))
 
     if s3:
+        bucket_name, service_endpoint, credentials = upload_to_s3.check_s3_credentials(s3_bucket_name, s3_service_endpoint, s3_credentials)
 
-        # Define S3 bucket and OME-Zarr dataset path
-
-        bucket_name = "cochlea-lightsheet"
-        zarr_path = f"{bucket_name}/{input_file}"
-
-        # Create an S3 filesystem
-        fs = s3fs.S3FileSystem(
-            client_kwargs={"endpoint_url": "https://s3.fs.gwdg.de"},
-            anon=False
-        )
+        s3_path, fs = upload_to_s3.get_s3_path(input_file, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials)
 
-        if not fs.exists(zarr_path):
-            print("Error: Path does not exist!")
-
-        # Open the OME-Zarr dataset
-        store = zarr.storage.FSStore(zarr_path, fs=fs)
-        print(f"Opening file {zarr_path} from the S3 bucket.")
-
-        with zarr.open(store, mode="r") as f:
+        with zarr.open(s3_path, mode="r") as f:
             raw = f[input_key][roi]
 
     else:
-        with z5py.File(input_file, "r") as f:
+        with zarr.open(input_file, mode="r") as f:
             raw = f[input_key][roi]
 
-    with z5py.File(output_file, "w") as f_out:
+    with zarr.open(output_file, mode="w") as f_out:
         f_out.create_dataset("raw", data=raw, compression="gzip")
 
 if __name__ == "__main__":
@@ -103,8 +98,15 @@ def main(input_file, output_dir, input_key, resolution, coords, roi_halo, s3):
     parser.add_argument('-r', "--resolution", type=float, default=0.38, help="Resolution of input in micrometer")
 
     parser.add_argument("--roi_halo", type=str, default="128,128,64", help="ROI halo around center coordinate in format 'x,y,z'")
+
     parser.add_argument("--s3", action="store_true", help="Use S3 bucket")
+    parser.add_argument("--s3_credentials", default=None, help="Input file containing S3 credentials")
+    parser.add_argument("--s3_bucket_name", default=None, help="S3 bucket name")
+    parser.add_argument("--s3_service_endpoint", default=None, help="S3 service endpoint")
 
     args = parser.parse_args()
 
-    main(args.input, args.output, args.input_key, args.resolution, args.coord, args.roi_halo, args.s3)
+    main(
+        args.input, args.output, args.coord, args.input_key, args.resolution, args.roi_halo,
+        args.s3, args.s3_credentials, args.s3_bucket_name, args.s3_service_endpoint,
+    )
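The hard-coded S3 handling removed above corresponds roughly to the following. The bucket name and endpoint are taken from the deleted lines, while the dataset name and key are placeholders; the new upload_to_s3.get_s3_path helper is assumed to wrap the same s3fs/zarr mechanics with configurable bucket, endpoint, and credentials:

```python
import s3fs
import zarr

# What the deleted inline code did: open an OME-Zarr dataset on the S3 bucket.
fs = s3fs.S3FileSystem(client_kwargs={"endpoint_url": "https://s3.fs.gwdg.de"}, anon=False)
store = zarr.storage.FSStore("cochlea-lightsheet/MyStack.ome.zarr", fs=fs)  # placeholder dataset

with zarr.open(store, mode="r") as f:
    raw = f["s0"][0:128, 0:128, 0:64]  # placeholder key and ROI
print(raw.shape)
```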
