2 changes: 2 additions & 0 deletions environment.yaml
@@ -10,5 +10,7 @@ dependencies:
- scikit-image
- pybdv
- pytorch
- s3fs
- torch_em
- z5py
- zarr
107 changes: 107 additions & 0 deletions flamingo_tools/s3_utils.py
@@ -0,0 +1,107 @@
import os

import s3fs
import zarr

from tqdm import tqdm

# Using the incucyte S3 bucket as a temporary measure.
MOBIE_FOLDER = "/mnt/lustre-emmy-hdd/projects/nim00007/data/moser/lightsheet/mobie"
SERVICE_ENDPOINT = "https://s3.gwdg.de/"
BUCKET_NAME = "incucyte-general/lightsheet"

# For MoBIE:
# https://s3.gwdg.de/incucyte-general/lightsheet

def check_s3_credentials(bucket_name, service_endpoint, credentials):
"""
Check whether the S3 parameters and credentials were set, either as function inputs or as exported environment variables.
"""
if bucket_name is None:
bucket_name = os.getenv('BUCKET_NAME')
if bucket_name is None:
raise ValueError("Provide a bucket name for accessing S3 data.\nEither by using an optional argument or exporting an environment variable:\n--s3_bucket_name <bucket_name>\nexport BUCKET_NAME=<bucket_name>")

if service_endpoint is None:
service_endpoint = os.getenv('SERVICE_ENDPOINT')
if service_endpoint is None:
raise ValueError("Provide a service endpoint for accessing S3 data.\nEither by using an optional argument or exporting an environment variable:\n--s3_service_endpoint <endpoint>\nexport SERVICE_ENDPOINT=<endpoint>")

if credentials is None:
access_key = os.getenv('AWS_ACCESS_KEY_ID')
secret_key = os.getenv('AWS_SECRET_ACCESS_KEY')
if access_key is None:
raise ValueError("Either provide a credential file as an optional argument or export an access key as an environment variable:\nexport AWS_ACCESS_KEY_ID=<access_key>")
if secret_key is None:
raise ValueError("Either provide a credential file as an optional argument or export a secret access key as an environment variable:\nexport AWS_SECRET_ACCESS_KEY=<secret_key>")

return bucket_name, service_endpoint, credentials
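
# Minimal usage sketch (hypothetical values): the S3 parameters can either be passed directly
# or exported beforehand, e.g. via `export BUCKET_NAME=<bucket_name>`:
# bucket_name, service_endpoint, credentials = check_s3_credentials(
#     bucket_name="incucyte-general/lightsheet", service_endpoint="https://s3.gwdg.de/", credentials=None,
# )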


def get_s3_path(
input_path,
bucket_name, service_endpoint,
credential_file=None,
):
"""
Get the S3 store for a file or folder, together with the underlying file system, based on S3 parameters and credentials.
"""
fs = create_s3_target(url=service_endpoint, anon=False, credential_file=credential_file)

zarr_path = f"{bucket_name}/{input_path}"

if not fs.exists(zarr_path):
print(f"Error: S3 path {zarr_path} does not exist!")

s3_path = zarr.storage.FSStore(zarr_path, fs=fs)

return s3_path, fs
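
# Usage sketch for reading remote zarr data (the dataset path below is hypothetical):
# s3_store, fs = get_s3_path(
#     "some_dataset/images/image.ome.zarr", bucket_name=BUCKET_NAME, service_endpoint=SERVICE_ENDPOINT,
# )
# data = zarr.open(s3_store, mode="r")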


def read_s3_credentials(credential_file):
"""
Read the AWS access key and secret key from a credential file.
"""
key, secret = None, None
with open(credential_file) as f:
for line in f:
if line.startswith("aws_access_key_id"):
key = line.rstrip("\n").strip().split(" ")[-1]
if line.startswith("aws_secret_access_key"):
secret = line.rstrip("\n").strip().split(" ")[-1]
if key is None or secret is None:
raise ValueError(f"Invalid credential file {credential_file}")
return key, secret
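
# The parser above expects a credential file with lines of the form "<key> = <value>", e.g.:
# aws_access_key_id = <access_key>
# aws_secret_access_key = <secret_key>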


def create_s3_target(url, anon=False, credential_file=None):
"""
Create file system for S3 bucket based on a service endpoint and an optional credential file.
If the credential file is not provided, the s3fs.S3FileSystem function checks the environment variables
AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY.
"""
client_kwargs = {"endpoint_url": url}
if credential_file is not None:
key, secret = read_s3_credentials(credential_file)
fs = s3fs.S3FileSystem(key=key, secret=secret, client_kwargs=client_kwargs)
else:
fs = s3fs.S3FileSystem(anon=anon, client_kwargs=client_kwargs)
return fs
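
# Sketch of the two credential paths described in the docstring (the endpoint value is an example):
# fs = create_s3_target("https://s3.gwdg.de/", credential_file="./credentials")
# or, relying on AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY being exported:
# fs = create_s3_target("https://s3.gwdg.de/")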


def upload_data():
"""
Upload the MoBIE xml metadata files from MOBIE_FOLDER to the S3 bucket.
"""
target = create_s3_target(
SERVICE_ENDPOINT,
credential_file="./credentials.incucyte"
)
to_upload = []
for root, dirs, files in os.walk(MOBIE_FOLDER):
dirs.sort()
for ff in files:
if ff.endswith(".xml"):
to_upload.append(os.path.join(root, ff))

print("Uploading", len(to_upload), "files to")

for path in tqdm(to_upload):
rel_path = os.path.relpath(path, MOBIE_FOLDER)
target.put(
path, os.path.join(BUCKET_NAME, rel_path)
)
109 changes: 100 additions & 9 deletions flamingo_tools/segmentation/postprocessing.py
@@ -1,16 +1,82 @@
import numpy as np
import vigra
import multiprocessing as mp
from concurrent import futures

from skimage import measure
from scipy.spatial import distance
from scipy.sparse import csr_matrix
from tqdm import tqdm
from sklearn.neighbors import NearestNeighbors

import elf.parallel as parallel
from elf.io import open_file
import nifty.tools as nt

def filter_isolated_objects(segmentation, distance_threshold=15, neighbor_threshold=5):
segmentation, n_ids, _ = vigra.analysis.relabelConsecutive(segmentation, start_label=1, keep_zeros=True)
def distance_nearest_neighbors(tsv_table, n_neighbors=10, expand_table=True):
"""
Calculate the average distance to the n nearest neighbors for each centroid.

props = measure.regionprops(segmentation)
coordinates = np.array([prop.centroid for prop in props])
:param DataFrame tsv_table: Table in MoBIE format containing the columns anchor_x, anchor_y, and anchor_z
:param int n_neighbors: Number of nearest neighbors to average over
:param bool expand_table: If True, append the average distances to the DataFrame as a new column
:returns: List of average distances
:rtype: list
"""
centroids = list(zip(tsv_table["anchor_x"], tsv_table["anchor_y"], tsv_table["anchor_z"]))

coordinates = np.array(centroids)

# The nearest neighbor of a point is the point itself, so query n_neighbors + 1 neighbors.
nbrs = NearestNeighbors(n_neighbors=n_neighbors+1).fit(coordinates)
distances, indices = nbrs.kneighbors(coordinates)

# Average distance to nearest neighbors
distance_avg = [sum(d) / len(d) for d in distances[:, 1:]]

if expand_table:
tsv_table[f"distance_nn{n_neighbors}"] = distance_avg

return distance_avg
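
# Usage sketch, assuming a MoBIE segmentation table loaded with pandas (the file name is hypothetical):
# import pandas as pd
# tsv_table = pd.read_csv("default.tsv", sep="\t")
# distance_avg = distance_nearest_neighbors(tsv_table, n_neighbors=10)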

def filter_isolated_objects(
segmentation, output_path, tsv_table=None,
distance_threshold=15, neighbor_threshold=5, min_size=1000,
output_key="segmentation_postprocessed",
):
"""
Postprocessing step to filter isolated objects from a segmentation.
Instance segmentations are filtered if they have fewer neighbors than a given threshold in a given distance around them.
Additionally, size filtering is possible if a TSV file is supplied.

:param dataset segmentation: Dataset containing the segmentation
:param str out_path: Output path for postprocessed segmentation
:param str tsv_file: Optional TSV file containing segmentation parameters in MoBIE format
:param int distance_threshold: Distance in micrometer to check for neighbors
:param int neighbor_threshold: Minimal number of neighbors for filtering
:param int min_size: Minimal number of pixels for filtering small instances
:param str output_key: Output key for postprocessed segmentation
"""
if tsv_table is not None:
n_pixels = tsv_table["n_pixels"].to_list()
label_ids = tsv_table["label_id"].to_list()
centroids = list(zip(tsv_table["anchor_x"], tsv_table["anchor_y"], tsv_table["anchor_z"]))
n_ids = len(label_ids)

# filter out cells with at most min_size pixels
if min_size is not None:
min_size_label_ids = [l for (l,n) in zip(label_ids, n_pixels) if n <= min_size]
centroids = [c for (c,l) in zip(centroids, label_ids) if l not in min_size_label_ids]
label_ids = [int(l) for l in label_ids if l not in min_size_label_ids]

coordinates = np.array(centroids)
label_ids = np.array(label_ids)

else:
segmentation, n_ids, _ = vigra.analysis.relabelConsecutive(segmentation[:], start_label=1, keep_zeros=True)
props = measure.regionprops(segmentation)
coordinates = np.array([prop.centroid for prop in props])
label_ids = np.unique(segmentation)[1:]

# Calculate pairwise distances and convert to a square matrix
dist_matrix = distance.pdist(coordinates)
@@ -22,13 +88,38 @@ def filter_isolated_objects(segmentation, distance_threshold=15, neighbor_thresh
# Sum each row to count neighbors
neighbor_counts = sparse_matrix.sum(axis=1)

seg_ids = np.unique(segmentation)[1:]
filter_mask = np.array(neighbor_counts < neighbor_threshold).squeeze()
filter_ids = seg_ids[filter_mask]
filter_ids = label_ids[filter_mask]

shape = segmentation.shape
block_shape = (128, 128, 128)
chunks = (128, 128, 128)

blocking = nt.blocking([0] * len(shape), shape, block_shape)

output = open_file(output_path, mode="a")

output_dataset = output.create_dataset(
output_key, shape=shape, dtype=segmentation.dtype,
chunks=chunks, compression="gzip"
)

def filter_chunk(block_id):
"""
Set all pixels within the chunk to zero if their ID is in filter_ids.
"""
block = blocking.getBlock(block_id)
volume_index = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end))
data = segmentation[volume_index]
data[np.isin(data, filter_ids)] = 0
output_dataset[volume_index] = data

# Limit the number of cores for parallelization.
n_threads = min(16, mp.cpu_count())

seg_filtered = segmentation.copy()
seg_filtered[np.isin(seg_filtered, filter_ids)] = 0
with futures.ThreadPoolExecutor(n_threads) as filter_pool:
list(tqdm(filter_pool.map(filter_chunk, range(blocking.numberOfBlocks)), total=blocking.numberOfBlocks))

seg_filtered, n_ids_filtered, _ = vigra.analysis.relabelConsecutive(seg_filtered, start_label=1, keep_zeros=True)
seg_filtered, n_ids_filtered, _ = parallel.relabel_consecutive(output_dataset, start_label=1, keep_zeros=True, block_shape=block_shape)

return seg_filtered, n_ids, n_ids_filtered
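
# Usage sketch, combining the table-based size filtering and the chunk-wise neighbor filtering above
# (paths and dataset keys are hypothetical):
# import pandas as pd
# seg = open_file("segmentation.zarr", mode="r")["segmentation"]
# tsv_table = pd.read_csv("default.tsv", sep="\t")
# seg_filtered, n_ids, n_ids_filtered = filter_isolated_objects(
#     seg, output_path="postprocessed.zarr", tsv_table=tsv_table,
#     distance_threshold=15, neighbor_threshold=5, min_size=1000,
# )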