
Commit ef8a5da

Initial adaptation to work with S3 data (#23)
1 parent f12f9d3 commit ef8a5da

File tree

9 files changed: +439 -120 lines changed


environment.yaml

Lines changed: 2 additions & 0 deletions
@@ -10,5 +10,7 @@ dependencies:
   - scikit-image
   - pybdv
   - pytorch
+  - s3fs
   - torch_em
   - z5py
+  - zarr
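
The two added dependencies work together: s3fs provides the S3 filesystem layer and zarr reads chunked arrays through it. A minimal sketch of that combination (endpoint, bucket, and dataset name below are placeholders, not part of this commit):

```python
import s3fs
import zarr

# Hypothetical endpoint and bucket, only to illustrate how s3fs and zarr combine.
fs = s3fs.S3FileSystem(anon=True, client_kwargs={"endpoint_url": "https://s3.example.org"})
store = zarr.storage.FSStore("example-bucket/example_volume.ome.zarr", fs=fs)
group = zarr.open(store, mode="r")
print(list(group.keys()))
```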

flamingo_tools/s3_utils.py

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
+import os
+
+import s3fs
+import zarr
+
+"""
+This script contains utility functions for processing data located on an S3 storage.
+The upload of data to the storage system should be performed with 'rclone'.
+"""
+
+# Dedicated bucket for cochlea lightsheet project
+MOBIE_FOLDER = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet"
+SERVICE_ENDPOINT = "https://s3.gwdg.de/"
+BUCKET_NAME = "cochlea-lightsheet"
+
+DEFAULT_CREDENTIALS = os.path.expanduser("~/.aws/credentials")
+
+# For MoBIE:
+# https://s3.gwdg.de/incucyte-general/lightsheet
+
+def check_s3_credentials(bucket_name, service_endpoint, credential_file):
+    """
+    Check if S3 parameter and credentials were set either as a function input or were exported as environment variables.
+    """
+    if bucket_name is None:
+        bucket_name = os.getenv('BUCKET_NAME')
+        if bucket_name is None:
+            if "BUCKET_NAME" in globals():
+                bucket_name = BUCKET_NAME
+            else:
+                raise ValueError("Provide a bucket name for accessing S3 data.\nEither by using an optional argument or exporting an environment variable:\n--s3_bucket_name <bucket_name>\nexport BUCKET_NAME=<bucket_name>")
+
+    if service_endpoint is None:
+        service_endpoint = os.getenv('SERVICE_ENDPOINT')
+        if service_endpoint is None:
+            if "SERVICE_ENDPOINT" in globals():
+                service_endpoint = SERVICE_ENDPOINT
+            else:
+                raise ValueError("Provide a service endpoint for accessing S3 data.\nEither by using an optional argument or exporting an environment variable:\n--s3_service_endpoint <endpoint>\nexport SERVICE_ENDPOINT=<endpoint>")
+
+    if credential_file is None:
+        access_key = os.getenv('AWS_ACCESS_KEY_ID')
+        secret_key = os.getenv('AWS_SECRET_ACCESS_KEY')
+
+        # check for default credentials if no credential_file is provided
+        if access_key is None:
+            if os.path.isfile(DEFAULT_CREDENTIALS):
+                access_key, _ = read_s3_credentials(credential_file=DEFAULT_CREDENTIALS)
+            else:
+                raise ValueError(f"Either provide a credential file as an optional argument, have credentials at '{DEFAULT_CREDENTIALS}', or export an access key as an environment variable:\nexport AWS_ACCESS_KEY_ID=<access_key>")
+        if secret_key is None:
+            # check for default credentials
+            if os.path.isfile(DEFAULT_CREDENTIALS):
+                _, secret_key = read_s3_credentials(credential_file=DEFAULT_CREDENTIALS)
+            else:
+                raise ValueError(f"Either provide a credential file as an optional argument, have credentials at '{DEFAULT_CREDENTIALS}', or export a secret access key as an environment variable:\nexport AWS_SECRET_ACCESS_KEY=<secret_key>")
+
+    else:
+        # check validity of credential file
+        _, _ = read_s3_credentials(credential_file=credential_file)
+
+    return bucket_name, service_endpoint, credential_file
+
+def get_s3_path(
+    input_path,
+    bucket_name=None, service_endpoint=None,
+    credential_file=None,
+):
+    """
+    Get S3 path for a file or folder and file system based on S3 parameters and credentials.
+    """
+    bucket_name, service_endpoint, credential_file = check_s3_credentials(bucket_name, service_endpoint, credential_file)
+
+    fs = create_s3_target(url=service_endpoint, anon=False, credential_file=credential_file)
+
+    zarr_path = f"{bucket_name}/{input_path}"
+
+    if not fs.exists(zarr_path):
+        print(f"Error: S3 path {zarr_path} does not exist!")
+
+    s3_path = zarr.storage.FSStore(zarr_path, fs=fs)
+
+    return s3_path, fs
+
+
+def read_s3_credentials(credential_file):
+    key, secret = None, None
+    with open(credential_file) as f:
+        for line in f:
+            if line.startswith("aws_access_key_id"):
+                key = line.rstrip("\n").strip().split(" ")[-1]
+            if line.startswith("aws_secret_access_key"):
+                secret = line.rstrip("\n").strip().split(" ")[-1]
+    if key is None or secret is None:
+        raise ValueError(f"Invalid credential file {credential_file}")
+    return key, secret
+
+
+def create_s3_target(url, anon=False, credential_file=None):
+    """
+    Create file system for S3 bucket based on a service endpoint and an optional credential file.
+    If the credential file is not provided, the s3fs.S3FileSystem function checks the environment variables
+    AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY.
+    """
+    client_kwargs = {"endpoint_url": url}
+    if credential_file is not None:
+        key, secret = read_s3_credentials(credential_file)
+        fs = s3fs.S3FileSystem(key=key, secret=secret, client_kwargs=client_kwargs)
+    else:
+        fs = s3fs.S3FileSystem(anon=anon, client_kwargs=client_kwargs)
+    return fs
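
Taken together, these helpers resolve the bucket, endpoint, and credentials, build an s3fs filesystem, and wrap the bucket path in a zarr FSStore. A minimal usage sketch (the dataset key below is a placeholder, and credentials are assumed to be available via environment variables or ~/.aws/credentials):

```python
import zarr

from flamingo_tools.s3_utils import get_s3_path

# Placeholder key inside the bucket; the real layout follows the MoBIE project structure.
input_path = "images/ome-zarr/example_channel.ome.zarr"

# Uses the module defaults for BUCKET_NAME and SERVICE_ENDPOINT, or the corresponding
# environment variables, plus AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY.
s3_store, fs = get_s3_path(input_path)

# The returned FSStore behaves like a local zarr container.
group = zarr.open(s3_store, mode="r")
print(list(group.keys()))
```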
Lines changed: 100 additions & 9 deletions
@@ -1,16 +1,82 @@
 import numpy as np
 import vigra
+import multiprocessing as mp
+from concurrent import futures
 
 from skimage import measure
 from scipy.spatial import distance
 from scipy.sparse import csr_matrix
+from tqdm import tqdm
+from sklearn.neighbors import NearestNeighbors
 
+import elf.parallel as parallel
+from elf.io import open_file
+import nifty.tools as nt
 
-def filter_isolated_objects(segmentation, distance_threshold=15, neighbor_threshold=5):
-    segmentation, n_ids, _ = vigra.analysis.relabelConsecutive(segmentation, start_label=1, keep_zeros=True)
+def distance_nearest_neighbors(tsv_table, n_neighbors=10, expand_table=True):
+    """
+    Calculate average distance of n nearest neighbors.
 
-    props = measure.regionprops(segmentation)
-    coordinates = np.array([prop.centroid for prop in props])
+    :param DataFrame tsv_table:
+    :param int n_neighbors: Number of nearest neighbors
+    :param bool expand_table: Flag for expanding DataFrame
+    :returns: List of average distances
+    :rtype: list
+    """
+    centroids = list(zip(tsv_table["anchor_x"], tsv_table["anchor_y"], tsv_table["anchor_z"]))
+
+    coordinates = np.array(centroids)
+
+    # nearest neighbor is always itself, so n_neighbors+=1
+    nbrs = NearestNeighbors(n_neighbors=n_neighbors+1).fit(coordinates)
+    distances, indices = nbrs.kneighbors(coordinates)
+
+    # Average distance to nearest neighbors
+    distance_avg = [sum(d) / len(d) for d in distances[:, 1:]]
+
+    if expand_table:
+        tsv_table['distance_nn'+str(n_neighbors)] = distance_avg
+
+    return distance_avg
+
+def filter_isolated_objects(
+    segmentation, output_path, tsv_table=None,
+    distance_threshold=15, neighbor_threshold=5, min_size=1000,
+    output_key="segmentation_postprocessed",
+):
+    """
+    Postprocessing step to filter isolated objects from a segmentation.
+    Instance segmentations are filtered if they have fewer neighbors than a given threshold in a given distance around them.
+    Additionally, size filtering is possible if a TSV file is supplied.
+
+    :param dataset segmentation: Dataset containing the segmentation
+    :param str output_path: Output path for postprocessed segmentation
+    :param DataFrame tsv_table: Optional table with segmentation parameters in MoBIE format
+    :param int distance_threshold: Distance in micrometer to check for neighbors
+    :param int neighbor_threshold: Minimal number of neighbors for filtering
+    :param int min_size: Minimal number of pixels for filtering small instances
+    :param str output_key: Output key for postprocessed segmentation
+    """
+    if tsv_table is not None:
+        n_pixels = tsv_table["n_pixels"].to_list()
+        label_ids = tsv_table["label_id"].to_list()
+        centroids = list(zip(tsv_table["anchor_x"], tsv_table["anchor_y"], tsv_table["anchor_z"]))
+        n_ids = len(label_ids)
+
+        # filter out cells smaller than min_size
+        if min_size is not None:
+            min_size_label_ids = [l for (l, n) in zip(label_ids, n_pixels) if n <= min_size]
+            centroids = [c for (c, l) in zip(centroids, label_ids) if l not in min_size_label_ids]
+            label_ids = [int(l) for l in label_ids if l not in min_size_label_ids]
+
+        coordinates = np.array(centroids)
+        label_ids = np.array(label_ids)
+
+    else:
+        segmentation, n_ids, _ = vigra.analysis.relabelConsecutive(segmentation[:], start_label=1, keep_zeros=True)
+        props = measure.regionprops(segmentation)
+        coordinates = np.array([prop.centroid for prop in props])
+        label_ids = np.unique(segmentation)[1:]
 
     # Calculate pairwise distances and convert to a square matrix
     dist_matrix = distance.pdist(coordinates)
@@ -22,13 +88,38 @@ def filter_isolated_objects(segmentation, distance_threshold=15, neighbor_thresh
     # Sum each row to count neighbors
     neighbor_counts = sparse_matrix.sum(axis=1)
 
-    seg_ids = np.unique(segmentation)[1:]
     filter_mask = np.array(neighbor_counts < neighbor_threshold).squeeze()
-    filter_ids = seg_ids[filter_mask]
+    filter_ids = label_ids[filter_mask]
+
+    shape = segmentation.shape
+    block_shape = (128, 128, 128)
+    chunks = (128, 128, 128)
+
+    blocking = nt.blocking([0] * len(shape), shape, block_shape)
+
+    output = open_file(output_path, mode="a")
+
+    output_dataset = output.create_dataset(
+        output_key, shape=shape, dtype=segmentation.dtype,
+        chunks=chunks, compression="gzip"
+    )
+
+    def filter_chunk(block_id):
+        """
+        Set all points within a chunk to zero if they match filter IDs.
+        """
+        block = blocking.getBlock(block_id)
+        volume_index = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
+        data = segmentation[volume_index]
+        data[np.isin(data, filter_ids)] = 0
+        output_dataset[volume_index] = data
+
+    # Limit the number of cores for parallelization.
+    n_threads = min(16, mp.cpu_count())
 
-    seg_filtered = segmentation.copy()
-    seg_filtered[np.isin(seg_filtered, filter_ids)] = 0
+    with futures.ThreadPoolExecutor(n_threads) as filter_pool:
+        list(tqdm(filter_pool.map(filter_chunk, range(blocking.numberOfBlocks)), total=blocking.numberOfBlocks))
 
-    seg_filtered, n_ids_filtered, _ = vigra.analysis.relabelConsecutive(seg_filtered, start_label=1, keep_zeros=True)
+    seg_filtered, n_ids_filtered, _ = parallel.relabel_consecutive(output_dataset, start_label=1, keep_zeros=True, block_shape=(128, 128, 128))
 
     return seg_filtered, n_ids, n_ids_filtered
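
Combined with the S3 helpers above, the reworked filter_isolated_objects can postprocess a remote segmentation block-wise without loading the full volume. A usage sketch under assumed inputs (the bucket keys, the multiscale level "s0", and the local output name are hypothetical; distance_nearest_neighbors and filter_isolated_objects are assumed to be imported from the postprocessing module changed above):

```python
import pandas as pd
import zarr

from flamingo_tools.s3_utils import get_s3_path

# Hypothetical dataset and table keys; the real layout follows the MoBIE project structure.
seg_store, fs = get_s3_path("images/ome-zarr/example_segmentation.ome.zarr")
segmentation = zarr.open(seg_store, mode="r")["s0"]

# MoBIE segmentation tables are tab-separated and provide anchor_* and n_pixels columns.
with fs.open("cochlea-lightsheet/tables/example_segmentation/default.tsv", "r") as f:
    tsv_table = pd.read_csv(f, sep="\t")

# Average distance to the 10 nearest neighbors, appended to the table as 'distance_nn10'.
distance_nearest_neighbors(tsv_table, n_neighbors=10, expand_table=True)

# Drop instances with fewer than 5 neighbors within 15 micrometer and at most 1000 pixels;
# the filtered labels are written block-wise to a local container and relabeled consecutively.
seg_filtered, n_ids, n_ids_filtered = filter_isolated_objects(
    segmentation, output_path="segmentation_postprocessed.n5", tsv_table=tsv_table,
    distance_threshold=15, neighbor_threshold=5, min_size=1000,
)
```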

0 commit comments
