2 changes: 2 additions & 0 deletions environment.yaml
@@ -10,5 +10,7 @@ dependencies:
- scikit-image
- pybdv
- pytorch
- s3fs
- torch_em
- z5py
- zarr
111 changes: 111 additions & 0 deletions flamingo_tools/s3_utils.py
@@ -0,0 +1,111 @@
import os

import s3fs
import zarr

"""
This script contains utility functions for processing data located on an S3 storage.
The upload of data to the storage system should be performed with 'rclone'.
"""

# Dedicated bucket for cochlea lightsheet project
MOBIE_FOLDER = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet"
SERVICE_ENDPOINT = "https://s3.gwdg.de/"
BUCKET_NAME = "cochlea-lightsheet"

DEFAULT_CREDENTIALS = os.path.expanduser("~/.aws/credentials")

# For MoBIE:
# https://s3.gwdg.de/incucyte-general/lightsheet

def check_s3_credentials(bucket_name, service_endpoint, credential_file):
    """
    Check if the S3 parameters and credentials were set, either as function inputs or as exported environment variables.
    """
    if bucket_name is None:
        bucket_name = os.getenv('BUCKET_NAME')
        if bucket_name is None:
            if "BUCKET_NAME" in globals():
                bucket_name = BUCKET_NAME
            else:
                raise ValueError("Provide a bucket name for accessing S3 data.\nEither by using an optional argument or exporting an environment variable:\n--s3_bucket_name <bucket_name>\nexport BUCKET_NAME=<bucket_name>")

    if service_endpoint is None:
        service_endpoint = os.getenv('SERVICE_ENDPOINT')
        if service_endpoint is None:
            if "SERVICE_ENDPOINT" in globals():
                service_endpoint = SERVICE_ENDPOINT
            else:
                raise ValueError("Provide a service endpoint for accessing S3 data.\nEither by using an optional argument or exporting an environment variable:\n--s3_service_endpoint <endpoint>\nexport SERVICE_ENDPOINT=<endpoint>")

    if credential_file is None:
        access_key = os.getenv('AWS_ACCESS_KEY_ID')
        secret_key = os.getenv('AWS_SECRET_ACCESS_KEY')

        # Check for default credentials if no credential file is provided.
        if access_key is None:
            if os.path.isfile(DEFAULT_CREDENTIALS):
                access_key, _ = read_s3_credentials(credential_file=DEFAULT_CREDENTIALS)
            else:
                raise ValueError(f"Either provide a credential file as an optional argument, have credentials at '{DEFAULT_CREDENTIALS}', or export an access key as an environment variable:\nexport AWS_ACCESS_KEY_ID=<access_key>")
        if secret_key is None:
            # Check for default credentials.
            if os.path.isfile(DEFAULT_CREDENTIALS):
                _, secret_key = read_s3_credentials(credential_file=DEFAULT_CREDENTIALS)
            else:
                raise ValueError(f"Either provide a credential file as an optional argument, have credentials at '{DEFAULT_CREDENTIALS}', or export a secret access key as an environment variable:\nexport AWS_SECRET_ACCESS_KEY=<secret_key>")

    else:
        # Check the validity of the credential file.
        _, _ = read_s3_credentials(credential_file=credential_file)

    return bucket_name, service_endpoint, credential_file

def get_s3_path(
    input_path,
    bucket_name=None, service_endpoint=None,
    credential_file=None,
):
    """
    Get the S3 store and file system for a file or folder, based on the S3 parameters and credentials.
    """
    bucket_name, service_endpoint, credential_file = check_s3_credentials(bucket_name, service_endpoint, credential_file)

    fs = create_s3_target(url=service_endpoint, anon=False, credential_file=credential_file)

    zarr_path = f"{bucket_name}/{input_path}"

    if not fs.exists(zarr_path):
        print(f"Error: S3 path {zarr_path} does not exist!")

    s3_path = zarr.storage.FSStore(zarr_path, fs=fs)

    return s3_path, fs


def read_s3_credentials(credential_file):
    key, secret = None, None
    with open(credential_file) as f:
        for line in f:
            if line.startswith("aws_access_key_id"):
                key = line.rstrip("\n").strip().split(" ")[-1]
            if line.startswith("aws_secret_access_key"):
                secret = line.rstrip("\n").strip().split(" ")[-1]
    if key is None or secret is None:
        raise ValueError(f"Invalid credential file {credential_file}")
    return key, secret


def create_s3_target(url, anon=False, credential_file=None):
    """
    Create a file system for an S3 bucket based on a service endpoint and an optional credential file.
    If the credential file is not provided, s3fs.S3FileSystem checks the environment variables
    AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY.
    """
    client_kwargs = {"endpoint_url": url}
    if credential_file is not None:
        key, secret = read_s3_credentials(credential_file)
        fs = s3fs.S3FileSystem(key=key, secret=secret, client_kwargs=client_kwargs)
    else:
        fs = s3fs.S3FileSystem(anon=anon, client_kwargs=client_kwargs)
    return fs
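
Reviewer note: a minimal usage sketch for the helpers above, assuming the bucket and endpoint constants defined in this file; the dataset path, the rclone remote name 'gwdg-s3', and the local source directory are hypothetical and only illustrate the intended workflow, not part of this PR.

# Usage sketch (not part of this PR); the dataset path below is a placeholder.
import zarr

from flamingo_tools.s3_utils import get_s3_path

# Data is uploaded separately with rclone, e.g. (remote name 'gwdg-s3' is an assumption):
#   rclone copy /local/PV.ome.zarr gwdg-s3:cochlea-lightsheet/M01/images/ome-zarr/PV.ome.zarr

# Resolve credentials (function arguments, environment variables, or ~/.aws/credentials)
# and open the dataset directly from the bucket.
s3_store, fs = get_s3_path(
    "M01/images/ome-zarr/PV.ome.zarr/s0",
    bucket_name="cochlea-lightsheet",
    service_endpoint="https://s3.gwdg.de/",
)
data = zarr.open(s3_store, mode="r")
print(data.shape)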
109 changes: 100 additions & 9 deletions flamingo_tools/segmentation/postprocessing.py
@@ -1,16 +1,82 @@
import numpy as np
import vigra
import multiprocessing as mp
from concurrent import futures

from skimage import measure
from scipy.spatial import distance
from scipy.sparse import csr_matrix
from tqdm import tqdm
from sklearn.neighbors import NearestNeighbors

import elf.parallel as parallel
from elf.io import open_file
import nifty.tools as nt

def distance_nearest_neighbors(tsv_table, n_neighbors=10, expand_table=True):
    """
    Calculate the average distance to the n nearest neighbors for each object.

    :param DataFrame tsv_table: Segmentation table in MoBIE format with anchor_x, anchor_y, and anchor_z columns
    :param int n_neighbors: Number of nearest neighbors
    :param bool expand_table: Flag for adding the average distances as a new column to the DataFrame
    :returns: List of average distances
    :rtype: list
    """
    centroids = list(zip(tsv_table["anchor_x"], tsv_table["anchor_y"], tsv_table["anchor_z"]))

    coordinates = np.array(centroids)

    # The nearest neighbor of a point is always itself, so query n_neighbors + 1 neighbors.
    nbrs = NearestNeighbors(n_neighbors=n_neighbors+1).fit(coordinates)
    distances, _ = nbrs.kneighbors(coordinates)

    # Average distance to the nearest neighbors.
    distance_avg = [sum(d) / len(d) for d in distances[:, 1:]]

    if expand_table:
        tsv_table['distance_nn'+str(n_neighbors)] = distance_avg

    return distance_avg

def filter_isolated_objects(
    segmentation, output_path, tsv_table=None,
    distance_threshold=15, neighbor_threshold=5, min_size=1000,
    output_key="segmentation_postprocessed",
):
    """
    Postprocessing step to filter isolated objects from a segmentation.
    Segmented instances are removed if they have fewer neighbors than a given threshold within a given distance around them.
    Additionally, size filtering is possible if a segmentation table is supplied.

    :param dataset segmentation: Dataset containing the segmentation
    :param str output_path: Output path for the postprocessed segmentation
    :param DataFrame tsv_table: Optional segmentation table in MoBIE format
    :param int distance_threshold: Distance in micrometers to check for neighbors
    :param int neighbor_threshold: Minimal number of neighbors required to keep an instance
    :param int min_size: Minimal number of pixels for filtering small instances
    :param str output_key: Output key for the postprocessed segmentation
    """
    if tsv_table is not None:
        n_pixels = tsv_table["n_pixels"].to_list()
        label_ids = tsv_table["label_id"].to_list()
        centroids = list(zip(tsv_table["anchor_x"], tsv_table["anchor_y"], tsv_table["anchor_z"]))
        n_ids = len(label_ids)

        # Filter out cells smaller than min_size.
        if min_size is not None:
            min_size_label_ids = [label for (label, n) in zip(label_ids, n_pixels) if n <= min_size]
            centroids = [c for (c, label) in zip(centroids, label_ids) if label not in min_size_label_ids]
            label_ids = [int(label) for label in label_ids if label not in min_size_label_ids]

        coordinates = np.array(centroids)
        label_ids = np.array(label_ids)

    else:
        segmentation, n_ids, _ = vigra.analysis.relabelConsecutive(segmentation[:], start_label=1, keep_zeros=True)
        props = measure.regionprops(segmentation)
        coordinates = np.array([prop.centroid for prop in props])
        label_ids = np.unique(segmentation)[1:]

    # Calculate pairwise distances and convert to a square matrix
    dist_matrix = distance.pdist(coordinates)
@@ -22,13 +88,38 @@ def filter_isolated_objects(segmentation, distance_threshold=15, neighbor_thresh
    # Sum each row to count neighbors
    neighbor_counts = sparse_matrix.sum(axis=1)

    filter_mask = np.array(neighbor_counts < neighbor_threshold).squeeze()
    filter_ids = label_ids[filter_mask]

    shape = segmentation.shape
    block_shape = (128, 128, 128)
    chunks = (128, 128, 128)

    blocking = nt.blocking([0] * len(shape), shape, block_shape)

    output = open_file(output_path, mode="a")

    output_dataset = output.create_dataset(
        output_key, shape=shape, dtype=segmentation.dtype,
        chunks=chunks, compression="gzip"
    )

    def filter_chunk(block_id):
        """
        Set all points within a chunk to zero if they match filter IDs.
        """
        block = blocking.getBlock(block_id)
        volume_index = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
        data = segmentation[volume_index]
        data[np.isin(data, filter_ids)] = 0
        output_dataset[volume_index] = data

    # Limit the number of cores for parallelization.
    n_threads = min(16, mp.cpu_count())

    with futures.ThreadPoolExecutor(n_threads) as filter_pool:
        list(tqdm(filter_pool.map(filter_chunk, range(blocking.numberOfBlocks)), total=blocking.numberOfBlocks))

    seg_filtered, n_ids_filtered, _ = parallel.relabel_consecutive(output_dataset, start_label=1, keep_zeros=True, block_shape=(128, 128, 128))

    return seg_filtered, n_ids, n_ids_filtered
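
Reviewer note: a sketch of how the new postprocessing entry points might be called, assuming a local Zarr container with a 'segmentation' dataset and a MoBIE-style TSV table; all paths, keys, and the table location below are placeholders, not part of this PR.

# Usage sketch (not part of this PR); paths, keys, and table layout are assumptions.
import pandas as pd

from elf.io import open_file
from flamingo_tools.segmentation.postprocessing import (
    distance_nearest_neighbors,
    filter_isolated_objects,
)

# MoBIE-style segmentation table with label_id, n_pixels, and anchor_* columns.
tsv_table = pd.read_csv("tables/SGN/default.tsv", sep="\t")

# Optionally annotate the table with the average distance to the 10 nearest neighbors.
distance_nearest_neighbors(tsv_table, n_neighbors=10, expand_table=True)

# Filter isolated and small instances, writing the result to a new dataset.
with open_file("segmentation.zarr", mode="r") as f:
    segmentation = f["segmentation"]
    seg_filtered, n_ids, n_ids_filtered = filter_isolated_objects(
        segmentation,
        output_path="segmentation_postprocessed.zarr",
        tsv_table=tsv_table,
        distance_threshold=15,
        neighbor_threshold=5,
        min_size=1000,
    )

print(f"Kept {n_ids_filtered} of {n_ids} objects after filtering.")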