From 06e674029c68f1b20b567af89f7a11601eb64358 Mon Sep 17 00:00:00 2001 From: Martin Schilling Date: Wed, 9 Apr 2025 17:02:26 +0200 Subject: [PATCH 1/7] Initial adaptation to work with S3 data --- flamingo_tools/segmentation/postprocessing.py | 82 +++++++++++++++++-- .../segmentation/unet_prediction.py | 68 +++++++++++++-- scripts/extract_block.py | 48 +++++------ scripts/prediction/count_cells.py | 31 +++++-- scripts/prediction/postprocess_seg.py | 57 ++++++++++--- scripts/prediction/upload_to_s3.py | 50 +++++++++++ 6 files changed, 280 insertions(+), 56 deletions(-) diff --git a/flamingo_tools/segmentation/postprocessing.py b/flamingo_tools/segmentation/postprocessing.py index eeb466c..46b6bcd 100644 --- a/flamingo_tools/segmentation/postprocessing.py +++ b/flamingo_tools/segmentation/postprocessing.py @@ -1,16 +1,55 @@ import numpy as np import vigra +import multiprocessing as mp +from concurrent import futures from skimage import measure from scipy.spatial import distance from scipy.sparse import csr_matrix +from tqdm import tqdm +import elf.parallel as parallel +from elf.io import open_file +import nifty.tools as nt -def filter_isolated_objects(segmentation, distance_threshold=15, neighbor_threshold=5): - segmentation, n_ids, _ = vigra.analysis.relabelConsecutive(segmentation, start_label=1, keep_zeros=True) +def filter_isolated_objects( + segmentation, output_path, tsv_table=None, + distance_threshold=15, neighbor_threshold=5, min_size=1000, + output_key="segmentation_postprocessed", + ): + """ + Postprocessing step to filter isolated objects from a segmentation. + Instance segmentations are filtered if they have fewer neighbors than a given threshold in a given distance around them. + Additionally, size filtering is possible if a TSV file is supplied. 
- props = measure.regionprops(segmentation) - coordinates = np.array([prop.centroid for prop in props]) + :param dataset segmentation: Dataset containing the segmentation + :param str out_path: Output path for postprocessed segmentation + :param str tsv_file: Optional TSV file containing segmentation parameters in MoBIE format + :param int distance_threshold: Distance in micrometer to check for neighbors + :param int neighbor_threshold: Minimal number of neighbors for filtering + :param int min_size: Minimal number of pixels for filtering small instances + :param str output_key: Output key for postprocessed segmentation + """ + if tsv_table is not None: + n_pixels = tsv_table["n_pixels"].to_list() + label_ids = tsv_table["label_id"].to_list() + centroids = list(zip(tsv_table["anchor_x"], tsv_table["anchor_y"], tsv_table["anchor_z"])) + n_ids = len(label_ids) + + # filter out cells smaller than min_size + if min_size is not None: + min_size_label_ids = [l for (l,n) in zip(label_ids, n_pixels) if n <= min_size] + centroids = [c for (c,l) in zip(centroids, label_ids) if l not in min_size_label_ids] + label_ids = [int(l) for l in label_ids if l not in min_size_label_ids] + + coordinates = np.array(centroids) + label_ids = np.array(label_ids) + + else: + segmentation, n_ids, _ = vigra.analysis.relabelConsecutive(segmentation[:], start_label=1, keep_zeros=True) + props = measure.regionprops(segmentation) + coordinates = np.array([prop.centroid for prop in props]) + label_ids = np.unique(segmentation)[1:] # Calculate pairwise distances and convert to a square matrix dist_matrix = distance.pdist(coordinates) @@ -22,13 +61,38 @@ def filter_isolated_objects(segmentation, distance_threshold=15, neighbor_thresh # Sum each row to count neighbors neighbor_counts = sparse_matrix.sum(axis=1) - seg_ids = np.unique(segmentation)[1:] filter_mask = np.array(neighbor_counts < neighbor_threshold).squeeze() - filter_ids = seg_ids[filter_mask] + filter_ids = label_ids[filter_mask] + + shape = segmentation.shape + block_shape=(128,128,128) + chunks=(128,128,128) + + blocking = nt.blocking([0] * len(shape), shape, block_shape) + + output = open_file(output_path, mode="a") + + output_dataset = output.create_dataset( + output_key, shape=shape, dtype=segmentation.dtype, + chunks=chunks, compression="gzip" + ) + + def filter_chunk(block_id): + """ + Set all points within a chunk to zero if they match filter IDs. + """ + block = blocking.getBlock(block_id) + volume_index = tuple(slice(beg, end) for beg, end in zip(block.begin, block.end)) + data = segmentation[volume_index] + data[np.isin(data, filter_ids)] = 0 + output_dataset[volume_index] = data + + # Limit the number of cores for parallelization. 
+ n_threads = min(16, mp.cpu_count()) - seg_filtered = segmentation.copy() - seg_filtered[np.isin(seg_filtered, filter_ids)] = 0 + with futures.ThreadPoolExecutor(n_threads) as filter_pool: + list(tqdm(filter_pool.map(filter_chunk, range(blocking.numberOfBlocks)), total=blocking.numberOfBlocks)) - seg_filtered, n_ids_filtered, _ = vigra.analysis.relabelConsecutive(seg_filtered, start_label=1, keep_zeros=True) + seg_filtered, n_ids_filtered, _ = parallel.relabel_consecutive(output_dataset, start_label=1, keep_zeros=True, block_shape=(128,128,128)) return seg_filtered, n_ids, n_ids_filtered diff --git a/flamingo_tools/segmentation/unet_prediction.py b/flamingo_tools/segmentation/unet_prediction.py index ca56d09..9cdc5f6 100644 --- a/flamingo_tools/segmentation/unet_prediction.py +++ b/flamingo_tools/segmentation/unet_prediction.py @@ -1,5 +1,6 @@ import multiprocessing as mp import os +import sys import warnings from concurrent import futures @@ -10,6 +11,7 @@ import vigra import torch import z5py +import zarr import json from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper @@ -18,6 +20,10 @@ from torch_em.util import load_model from torch_em.util.prediction import predict_with_halo from tqdm import tqdm +from inspect import getsourcefile + +sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(getsourcefile(lambda:0)))), "scripts", "prediction")) +import upload_to_s3 """ Prediction using distance U-Net. @@ -43,7 +49,7 @@ def ndim(self): return self._volume.ndim - 1 -def prediction_impl(input_path, input_key, output_folder, model_path, scale, block_shape, halo, prediction_instances=1, slurm_task_id=0, mean=None, std=None): +def prediction_impl(input_path, input_key, output_folder, model_path, scale, block_shape, halo, prediction_instances=1, slurm_task_id=0, mean=None, std=None, s3=None): with warnings.catch_warnings(): warnings.simplefilter("ignore") if os.path.isdir(model_path): @@ -56,6 +62,9 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo if input_key is None: input_ = imageio.imread(input_path) + elif s3 is not None: + with zarr.open(input_path, mode="r") as f: + input_ = f[input_key] else: input_ = open_file(input_path, "r")[input_key] @@ -138,7 +147,7 @@ def postprocess(x): return original_shape -def find_mask(input_path, input_key, output_folder): +def find_mask(input_path, input_key, output_folder, s3=None): mask_path = os.path.join(output_folder, "mask.zarr") f = z5py.File(mask_path, "a") @@ -149,6 +158,10 @@ def find_mask(input_path, input_key, output_folder): if input_key is None: raw = imageio.imread(input_path) chunks = (64, 64, 64) + elif s3 is not None: + with zarr.open(input_path, mode="r") as fin: + raw = fin[input_key] + chunks = raw.chunks else: fin = open_file(input_path, "r") raw = fin[input_key] @@ -243,7 +256,10 @@ def write_block(block_id): tp.map(write_block, range(blocking.numberOfBlocks)) -def calc_mean_and_std(input_path, input_key, output_folder): +def calc_mean_and_std( + input_path, input_key, output_folder, + s3=None, + ): """ Calculate mean and standard deviation of full volume. Parameters are saved in 'mean_std.json' within the output folder. 
@@ -254,6 +270,9 @@ def calc_mean_and_std(input_path, input_key, output_folder): if input_key is None: input_ = imageio.imread(input_path) + elif s3 is not None: + with zarr.open(input_path, mode="r") as f: + input_ = f[input_key] else: input_ = open_file(input_path, "r")[input_key] @@ -267,6 +286,7 @@ def calc_mean_and_std(input_path, input_key, output_folder): with open(json_file, "w") as f: json.dump(ddict, f) + def run_unet_prediction( input_path, input_key, output_folder, model_path, @@ -288,32 +308,63 @@ def run_unet_prediction( def run_unet_prediction_preprocess_slurm( input_path, input_key, output_folder, + s3=None, s3_bucket_name=None, s3_service_endpoint=None, s3_credentials=None, ): """ Pre-processing for the parallel prediction with U-Net models. Masks are stored in mask.zarr in the output folder. The mean and standard deviation are precomputed for later usage during prediction - and stored in a JSON file within the output folder as mean_std.json + and stored in a JSON file within the output folder as mean_std.json. """ - find_mask(input_path, input_key, output_folder) - calc_mean_and_std(input_path, input_key, output_folder) + if s3 is not None: + bucket_name, service_endpoint, credentials = upload_to_s3.check_s3_credentials(s3_bucket_name, s3_service_endpoint, s3_credentials) + + input_path, fs = upload_to_s3.get_s3_path(input_path, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials) + + if not os.path.isdir(os.path.join(output_folder, "mask.zarr")): + find_mask(input_path, input_key, output_folder, s3=s3) + + calc_mean_and_std(input_path, input_key, output_folder, s3=s3) + def run_unet_prediction_slurm( input_path, input_key, output_folder, model_path, scale=None, block_shape=None, halo=None, prediction_instances=1, + s3=None, s3_bucket_name=None, s3_service_endpoint=None, s3_credentials=None, ): + """ + Run prediction of distance U-Net for data stored locally or on an S3 bucket. + + :param str input_path: File path to input data + :param str input_key: Input key for data in ome.zarr format + :param str output_folder: Output folder for prediction.zarr + :param str model_path: File path to distance U-Net model + :param float scale: + :param tuple block_shape: + :param tuple halo: + :param int prediction_instances: Number of workers for parallel computation within slurm array + :param bool s3: Flag for accessing data on S3 bucket + :param str s3_bucket_name: S3 bucket name. Optional if BUCKET_NAME has been exported + :param str s3_service_endpoint: S3 service endpoint. Optional if SERVICE_ENDPOINT has been exported + :param str s3_credentials: Path to file containing S3 credentials + """ os.makedirs(output_folder, exist_ok=True) prediction_instances = int(prediction_instances) slurm_task_id = os.environ.get("SLURM_ARRAY_TASK_ID") + if s3 is not None: + bucket_name, service_endpoint, credentials = upload_to_s3.check_s3_credentials(s3_bucket_name, s3_service_endpoint, s3_credentials) + + input_path, fs = upload_to_s3.get_s3_path(input_path, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials) + if slurm_task_id is not None: slurm_task_id = int(slurm_task_id) else: raise ValueError("The SLURM_ARRAY_TASK_ID is not set. 
Ensure that you are using the '-a' option with SBATCH.") if not os.path.isdir(os.path.join(output_folder, "mask.zarr")): - find_mask(input_path, input_key, output_folder) + find_mask(input_path, input_key, output_folder, s3=s3) # get pre-computed mean and standard deviation of full volume from JSON file if os.path.isfile(os.path.join(output_folder, "mean_std.json")): @@ -328,9 +379,10 @@ def run_unet_prediction_slurm( original_shape = prediction_impl( input_path, input_key, output_folder, model_path, scale, block_shape, halo, prediction_instances=prediction_instances, slurm_task_id=slurm_task_id, - mean=mean, std=std, + mean=mean, std=std, s3=s3, ) + # does NOT need GPU, FIXME: only run on CPU def run_unet_segmentation_slurm(output_folder, min_size): min_size = int(min_size) diff --git a/scripts/extract_block.py b/scripts/extract_block.py index f248dc2..0015b3b 100644 --- a/scripts/extract_block.py +++ b/scripts/extract_block.py @@ -1,10 +1,14 @@ import os +import sys import argparse import numpy as np import z5py import zarr -import s3fs +from inspect import getsourcefile + +sys.path.append(os.path.join(os.path.dirname(getsourcefile(lambda:0)), "prediction")) +import upload_to_s3 """ This script extracts data around an input center coordinate in a given ROI halo. @@ -18,7 +22,10 @@ """ -def main(input_file, output_dir, input_key, resolution, coords, roi_halo, s3): +def main( + input_file, output_dir, coords, input_key, resolution, roi_halo, + s3, s3_credentials, s3_bucket_name, s3_service_endpoint, + ): """ :param str input_file: File path to input folder in n5 format @@ -28,6 +35,9 @@ def main(input_file, output_dir, input_key, resolution, coords, roi_halo, s3): :param str coords: Center coordinates of extracted 3D volume in format 'x,y,z' :param str roi_halo: ROI halo of extracted 3D volume in format 'x,y,z' :param bool s3: Flag for using an S3 bucket + :param str s3_credentials: Path to file containing S3 credentials + :param str s3_bucket_name: S3 bucket name. Optional if BUCKET_NAME has been exported + :param str s3_service_endpoint: S3 service endpoint. 
Optional if SERVICE_ENDPOINT has been exported """ coords = [int(r) for r in coords.split(",")] @@ -61,33 +71,18 @@ def main(input_file, output_dir, input_key, resolution, coords, roi_halo, s3): roi = tuple(slice(co - rh, co + rh) for co, rh in zip(coords, roi_halo)) if s3: + bucket_name, service_endpoint, credentials = upload_to_s3.check_s3_credentials(s3_bucket_name, s3_service_endpoint, s3_credentials) - # Define S3 bucket and OME-Zarr dataset path - - bucket_name = "cochlea-lightsheet" - zarr_path = f"{bucket_name}/{input_file}" - - # Create an S3 filesystem - fs = s3fs.S3FileSystem( - client_kwargs={"endpoint_url": "https://s3.fs.gwdg.de"}, - anon=False - ) + s3_path, fs = upload_to_s3.get_s3_path(input_file, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials) - if not fs.exists(zarr_path): - print("Error: Path does not exist!") - - # Open the OME-Zarr dataset - store = zarr.storage.FSStore(zarr_path, fs=fs) - print(f"Opening file {zarr_path} from the S3 bucket.") - - with zarr.open(store, mode="r") as f: + with zarr.open(s3_path, mode="r") as f: raw = f[input_key][roi] else: - with z5py.File(input_file, "r") as f: + with zarr.open(input_file, mode="r") as f: raw = f[input_key][roi] - with z5py.File(output_file, "w") as f_out: + with zarr.open(output_file, mode="w") as f_out: f_out.create_dataset("raw", data=raw, compression="gzip") if __name__ == "__main__": @@ -103,8 +98,15 @@ def main(input_file, output_dir, input_key, resolution, coords, roi_halo, s3): parser.add_argument('-r', "--resolution", type=float, default=0.38, help="Resolution of input in micrometer") parser.add_argument("--roi_halo", type=str, default="128,128,64", help="ROI halo around center coordinate in format 'x,y,z'") + parser.add_argument("--s3", action="store_true", help="Use S3 bucket") + parser.add_argument("--s3_credentials", default=None, help="Input file containing S3 credentials") + parser.add_argument("--s3_bucket_name", default=None, help="S3 bucket name") + parser.add_argument("--s3_service_endpoint", default=None, help="S3 service endpoint") args = parser.parse_args() - main(args.input, args.output, args.input_key, args.resolution, args.coord, args.roi_halo, args.s3) + main( + args.input, args.output, args.coord, args.input_key, args.resolution, args.roi_halo, + args.s3, args.s3_credentials, args.s3_bucket_name, args.s3_service_endpoint, + ) diff --git a/scripts/prediction/count_cells.py b/scripts/prediction/count_cells.py index b92f284..dea7a2d 100644 --- a/scripts/prediction/count_cells.py +++ b/scripts/prediction/count_cells.py @@ -2,23 +2,44 @@ import os import sys +import zarr + from elf.parallel import unique from elf.io import open_file +import upload_to_s3 sys.path.append("../..") def main(): parser = argparse.ArgumentParser() - parser.add_argument("-o", "--output_folder", type=str, required=True, help="Output directory containing segmentation.zarr") + parser.add_argument("-o", "--output_folder", type=str, default=None, help="Output directory containing segmentation.zarr") + + parser.add_argument('-k', "--input_key", type=str, default="segmentation", help="Input key for data in input file") parser.add_argument("-m", "--min_size", type=int, default=1000, help="Minimal number of voxel size for counting object") + + parser.add_argument("--s3_input", default=None, help="Input file path on S3 bucket") + parser.add_argument("--s3_credentials", default=None, help="Input file containing S3 credentials") + parser.add_argument("--s3_bucket_name", default=None, help="S3 
bucket name") + parser.add_argument("--s3_service_endpoint", default=None, help="S3 service endpoint") + args = parser.parse_args() - seg_path = os.path.join(args.output_folder, "segmentation.zarr") - seg_key = "segmentation" + if args.output_folder is not None: + seg_path = os.path.join(args.output_folder, "segmentation.zarr") + elif args.s3_input is None: + raise ValueError("Either provide an output_folder containing 'segmentation.zarr' or an S3 input.") + + if args.s3_input is not None: + bucket_name, service_endpoint, credentials = upload_to_s3.check_s3_credentials(args.s3_bucket_name, args.s3_service_endpoint, args.s3_credentials) + + s3_path, fs = upload_to_s3.get_s3_path(args.s3_input, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials) + with zarr.open(s3_path, mode="r") as f: + dataset = f[args.input_key] - file = open_file(seg_path, mode='r') - dataset = file[seg_key] + else: + segmentation = open_file(seg_path, mode='r') + dataset = segmentation[args.input_key] ids, counts = unique(dataset, return_counts=True) diff --git a/scripts/prediction/postprocess_seg.py b/scripts/prediction/postprocess_seg.py index 2d1a778..bc186cf 100644 --- a/scripts/prediction/postprocess_seg.py +++ b/scripts/prediction/postprocess_seg.py @@ -2,7 +2,10 @@ import os import sys -import z5py +import pandas as pd +import zarr + +import upload_to_s3 sys.path.append("../..") @@ -10,25 +13,57 @@ def main(): from flamingo_tools.segmentation import filter_isolated_objects - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser( + description="Script for postprocessing segmentation data in zarr format. Either locally or on an S3 bucket.") + parser.add_argument("-o", "--output_folder", required=True) + + parser.add_argument("-t", "--tsv", default=None, help="TSV-file in MoBIE format which contains information about the segmentation") + parser.add_argument('-k', "--input_key", type=str, default="segmentation", help="Input key for data in input file") + parser.add_argument("--output_key", type=str, default="segmentation_postprocessed", help="Output key for data in input file") + + parser.add_argument("--s3_input", default=None, help="Input file path on S3 bucket") + parser.add_argument("--s3_credentials", default=None, help="Input file containing S3 credentials. Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported") + parser.add_argument("--s3_bucket_name", default=None, help="S3 bucket name. Optional if BUCKET_NAME was exported") + parser.add_argument("--s3_service_endpoint", default=None, help="S3 service endpoint. 
Optional if SERVICE_ENDPOINT was exported") + + parser.add_argument("--min_size", type=int, default=None, help="Minimal number of voxel size for counting object") + parser.add_argument("--distance_threshold", type=int, default=15, help="Distance in micrometer to check for neighbors") + parser.add_argument("--neighbor_threshold", type=int, default=5, help="Minimal number of neighbors for filtering") + args = parser.parse_args() seg_path = os.path.join(args.output_folder, "segmentation.zarr") - seg_key = "segmentation" - with z5py.File(seg_path, "r") as f: - segmentation = f[seg_key][:] + tsv_table=None + + if args.s3_input is not None: + bucket_name, service_endpoint, credentials = upload_to_s3.check_s3_credentials(args.s3_bucket_name, args.s3_service_endpoint, args.s3_credentials) + + s3_path, fs = upload_to_s3.get_s3_path(args.s3_input, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials) + with zarr.open(s3_path, mode="r") as f: + segmentation = f[args.input_key] + + if args.tsv is not None: + tsv_path, fs = upload_to_s3.get_s3_path(args.tsv, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials) + with fs.open(tsv_path, 'r') as f: + tsv_table = pd.read_csv(f, sep="\t") + + else: + with zarr.open(seg_path, mode="r") as f: + segmentation = f[args.input_key] - seg_filtered, n_pre, n_post = filter_isolated_objects(segmentation) + if args.tsv is not None: + with open(args.tsv, 'r') as f: + tsv_table = pd.read_csv(f, sep="\t") - with z5py.File(seg_path, "a") as f: - chunks = f[seg_key].chunks - f.create_dataset( - "segmentation_postprocessed", data=seg_filtered, compression="gzip", - chunks=chunks, dtype=seg_filtered.dtype + seg_filtered, n_pre, n_post = filter_isolated_objects( + segmentation, output_path=seg_path, tsv_table=tsv_table, min_size=args.min_size, + distance_threshold=args.distance_threshold, neighbor_threshold=args.neighbor_threshold, + output_key=args.output_key, ) + print(f"Number of pre-filtered objects: {n_pre}\nNumber of post-filtered objects: {n_post}") if __name__ == "__main__": main() diff --git a/scripts/prediction/upload_to_s3.py b/scripts/prediction/upload_to_s3.py index 061526a..96b75c2 100644 --- a/scripts/prediction/upload_to_s3.py +++ b/scripts/prediction/upload_to_s3.py @@ -1,6 +1,7 @@ import os import s3fs +import zarr from mobie.metadata import add_remote_project_metadata from tqdm import tqdm @@ -13,6 +14,50 @@ # For MoBIE: # https://s3.gwdg.de/incucyte-general/lightsheet +def check_s3_credentials(bucket_name, service_endpoint, credentials): + """ + Check if S3 parameter and credentials were set either as a function input or were exported as environment variables. 
+ """ + if bucket_name is None: + bucket_name = os.getenv('BUCKET_NAME') + if bucket_name is None: + raise ValueError("Provide a bucket name for accessing S3 data.\nEither by using an optional argument or exporting an environment variable:\n--s3_bucket_name \nexport BUCKET_NAME=") + + if service_endpoint is None: + service_endpoint = os.getenv('SERVICE_ENDPOINT') + if service_endpoint is None: + raise ValueError("Provide a service endpoint for accessing S3 data.\nEither by using an optional argument or exporting an environment variable:\n--s3_service_endpoint \nexport SERVICE_ENDPOINT=") + + if credentials is None: + access_key = os.getenv('AWS_ACCESS_KEY_ID') + secret_key = os.getenv('AWS_SECRET_ACCESS_KEY') + if access_key is None: + raise ValueError("Either provide a credential file as an optional argument or export an access key as an environment variable:\nexport AWS_ACCESS_KEY_ID=") + if secret_key is None: + raise ValueError("Either provide a credential file as an optional argument or export a secret access key as an environment variable:\nexport AWS_SECRET_ACCESS_KEY=") + + return bucket_name, service_endpoint, credentials + + +def get_s3_path( + input_path, + bucket_name, service_endpoint, + credential_file=None, +): + """ + Get S3 path for a file or folder and file system based on S3 parameters and credentials. + """ + fs = create_s3_target(url=service_endpoint, anon=False, credential_file=credential_file) + + zarr_path=f"{bucket_name}/{input_path}" + + if not fs.exists(zarr_path): + print(f"Error: S3 path {zarr_path} does not exist!") + + s3_path = zarr.storage.FSStore(zarr_path, fs=fs) + + return s3_path, fs + def read_s3_credentials(credential_file): key, secret = None, None @@ -28,6 +73,11 @@ def read_s3_credentials(credential_file): def create_s3_target(url, anon=False, credential_file=None): + """ + Create file system for S3 bucket based on a service endpoint and an optional credential file. + If the credential file is not provided, the s3fs.S3FileSystem function checks the environment variables + AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY. 
+ """ client_kwargs = {"endpoint_url": url} if credential_file is not None: key, secret = read_s3_credentials(credential_file) From eb386469848220265fbf80af13396293f8d3b4ba Mon Sep 17 00:00:00 2001 From: Martin Schilling Date: Thu, 10 Apr 2025 17:54:24 +0200 Subject: [PATCH 2/7] Updated packages for environment --- environment.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/environment.yaml b/environment.yaml index d29e3f8..2aedfb7 100644 --- a/environment.yaml +++ b/environment.yaml @@ -10,5 +10,7 @@ dependencies: - scikit-image - pybdv - pytorch + - s3fs - torch_em - z5py + - zarr From 760d6cabb99560f06b6c804746f64d2aa971a862 Mon Sep 17 00:00:00 2001 From: Martin Schilling Date: Thu, 10 Apr 2025 17:59:26 +0200 Subject: [PATCH 3/7] Moved S3 utils --- flamingo_tools/s3_utils.py | 112 ++++++++++++++++++ .../segmentation/unet_prediction.py | 14 +-- scripts/extract_block.py | 11 +- scripts/prediction/count_cells.py | 9 +- scripts/prediction/postprocess_seg.py | 9 +- scripts/prediction/upload_to_s3.py | 105 +--------------- 6 files changed, 132 insertions(+), 128 deletions(-) create mode 100644 flamingo_tools/s3_utils.py diff --git a/flamingo_tools/s3_utils.py b/flamingo_tools/s3_utils.py new file mode 100644 index 0000000..9948eb0 --- /dev/null +++ b/flamingo_tools/s3_utils.py @@ -0,0 +1,112 @@ +import os + +import s3fs +import zarr + +from mobie.metadata import add_remote_project_metadata +from tqdm import tqdm + +# Using incucyte s3 as a temporary measure. +MOBIE_FOLDER = "/mnt/lustre-emmy-hdd/projects/nim00007/data/moser/lightsheet/mobie" +SERVICE_ENDPOINT = "https://s3.gwdg.de/" +BUCKET_NAME = "incucyte-general/lightsheet" + +# For MoBIE: +# https://s3.gwdg.de/incucyte-general/lightsheet + +def check_s3_credentials(bucket_name, service_endpoint, credentials): + """ + Check if S3 parameter and credentials were set either as a function input or were exported as environment variables. + """ + if bucket_name is None: + bucket_name = os.getenv('BUCKET_NAME') + if bucket_name is None: + raise ValueError("Provide a bucket name for accessing S3 data.\nEither by using an optional argument or exporting an environment variable:\n--s3_bucket_name \nexport BUCKET_NAME=") + + if service_endpoint is None: + service_endpoint = os.getenv('SERVICE_ENDPOINT') + if service_endpoint is None: + raise ValueError("Provide a service endpoint for accessing S3 data.\nEither by using an optional argument or exporting an environment variable:\n--s3_service_endpoint \nexport SERVICE_ENDPOINT=") + + if credentials is None: + access_key = os.getenv('AWS_ACCESS_KEY_ID') + secret_key = os.getenv('AWS_SECRET_ACCESS_KEY') + if access_key is None: + raise ValueError("Either provide a credential file as an optional argument or export an access key as an environment variable:\nexport AWS_ACCESS_KEY_ID=") + if secret_key is None: + raise ValueError("Either provide a credential file as an optional argument or export a secret access key as an environment variable:\nexport AWS_SECRET_ACCESS_KEY=") + + return bucket_name, service_endpoint, credentials + + +def get_s3_path( + input_path, + bucket_name, service_endpoint, + credential_file=None, +): + """ + Get S3 path for a file or folder and file system based on S3 parameters and credentials. 
+ """ + fs = create_s3_target(url=service_endpoint, anon=False, credential_file=credential_file) + + zarr_path=f"{bucket_name}/{input_path}" + + if not fs.exists(zarr_path): + print(f"Error: S3 path {zarr_path} does not exist!") + + s3_path = zarr.storage.FSStore(zarr_path, fs=fs) + + return s3_path, fs + + +def read_s3_credentials(credential_file): + key, secret = None, None + with open(credential_file) as f: + for line in f: + if line.startswith("aws_access_key_id"): + key = line.rstrip("\n").strip().split(" ")[-1] + if line.startswith("aws_secret_access_key"): + secret = line.rstrip("\n").strip().split(" ")[-1] + if key is None or secret is None: + raise ValueError(f"Invalid credential file {credential_file}") + return key, secret + + +def create_s3_target(url, anon=False, credential_file=None): + """ + Create file system for S3 bucket based on a service endpoint and an optional credential file. + If the credential file is not provided, the s3fs.S3FileSystem function checks the environment variables + AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY. + """ + client_kwargs = {"endpoint_url": url} + if credential_file is not None: + key, secret = read_s3_credentials(credential_file) + fs = s3fs.S3FileSystem(key=key, secret=secret, client_kwargs=client_kwargs) + else: + fs = s3fs.S3FileSystem(anon=anon, client_kwargs=client_kwargs) + return fs + + +def remote_metadata(): + add_remote_project_metadata(MOBIE_FOLDER, BUCKET_NAME, SERVICE_ENDPOINT) + + +def upload_data(): + target = create_s3_target( + SERVICE_ENDPOINT, + credential_file="./credentials.incucyte" + ) + to_upload = [] + for root, dirs, files in os.walk(MOBIE_FOLDER): + dirs.sort() + for ff in files: + if ff.endswith(".xml"): + to_upload.append(os.path.join(root, ff)) + + print("Uploading", len(to_upload), "files to") + + for path in tqdm(to_upload): + rel_path = os.path.relpath(path, MOBIE_FOLDER) + target.put( + path, os.path.join(BUCKET_NAME, rel_path) + ) \ No newline at end of file diff --git a/flamingo_tools/segmentation/unet_prediction.py b/flamingo_tools/segmentation/unet_prediction.py index 9cdc5f6..014d432 100644 --- a/flamingo_tools/segmentation/unet_prediction.py +++ b/flamingo_tools/segmentation/unet_prediction.py @@ -20,10 +20,8 @@ from torch_em.util import load_model from torch_em.util.prediction import predict_with_halo from tqdm import tqdm -from inspect import getsourcefile -sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(getsourcefile(lambda:0)))), "scripts", "prediction")) -import upload_to_s3 +import flamingo_tools.s3_utils as s3_utils """ Prediction using distance U-Net. @@ -97,7 +95,7 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo # Compute the global mean and standard deviation. n_threads = min(16, mp.cpu_count()) mean, std = parallel.mean_and_std( - input_, block_shape=block_shape, n_threads=n_threads, verbose=True, + input_, block_shape=tuple([2* i for i in input_.chunks]), n_threads=n_threads, verbose=True, mask=image_mask ) print("Mean and standard deviation computed for the full volume:") @@ -317,9 +315,9 @@ def run_unet_prediction_preprocess_slurm( and stored in a JSON file within the output folder as mean_std.json. 
""" if s3 is not None: - bucket_name, service_endpoint, credentials = upload_to_s3.check_s3_credentials(s3_bucket_name, s3_service_endpoint, s3_credentials) + bucket_name, service_endpoint, credentials = s3_utils.check_s3_credentials(s3_bucket_name, s3_service_endpoint, s3_credentials) - input_path, fs = upload_to_s3.get_s3_path(input_path, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials) + input_path, fs = s3_utils.get_s3_path(input_path, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials) if not os.path.isdir(os.path.join(output_folder, "mask.zarr")): find_mask(input_path, input_key, output_folder, s3=s3) @@ -354,9 +352,9 @@ def run_unet_prediction_slurm( slurm_task_id = os.environ.get("SLURM_ARRAY_TASK_ID") if s3 is not None: - bucket_name, service_endpoint, credentials = upload_to_s3.check_s3_credentials(s3_bucket_name, s3_service_endpoint, s3_credentials) + bucket_name, service_endpoint, credentials = s3_utils.check_s3_credentials(s3_bucket_name, s3_service_endpoint, s3_credentials) - input_path, fs = upload_to_s3.get_s3_path(input_path, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials) + input_path, fs = s3_utils.get_s3_path(input_path, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials) if slurm_task_id is not None: slurm_task_id = int(slurm_task_id) diff --git a/scripts/extract_block.py b/scripts/extract_block.py index 0015b3b..d8fc376 100644 --- a/scripts/extract_block.py +++ b/scripts/extract_block.py @@ -1,14 +1,9 @@ import os -import sys import argparse import numpy as np -import z5py import zarr -from inspect import getsourcefile - -sys.path.append(os.path.join(os.path.dirname(getsourcefile(lambda:0)), "prediction")) -import upload_to_s3 +import flamingo_tools.s3_utils as s3_utils """ This script extracts data around an input center coordinate in a given ROI halo. 
@@ -71,9 +66,9 @@ def main( roi = tuple(slice(co - rh, co + rh) for co, rh in zip(coords, roi_halo)) if s3: - bucket_name, service_endpoint, credentials = upload_to_s3.check_s3_credentials(s3_bucket_name, s3_service_endpoint, s3_credentials) + bucket_name, service_endpoint, credentials = s3_utils.check_s3_credentials(s3_bucket_name, s3_service_endpoint, s3_credentials) - s3_path, fs = upload_to_s3.get_s3_path(input_file, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials) + s3_path, fs = s3_utils.get_s3_path(input_file, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials) with zarr.open(s3_path, mode="r") as f: raw = f[input_key][roi] diff --git a/scripts/prediction/count_cells.py b/scripts/prediction/count_cells.py index dea7a2d..8bfef8d 100644 --- a/scripts/prediction/count_cells.py +++ b/scripts/prediction/count_cells.py @@ -1,15 +1,12 @@ import argparse import os -import sys import zarr from elf.parallel import unique from elf.io import open_file -import upload_to_s3 - -sys.path.append("../..") +import flamingo_tools.s3_utils as s3_utils def main(): parser = argparse.ArgumentParser() @@ -31,9 +28,9 @@ def main(): raise ValueError("Either provide an output_folder containing 'segmentation.zarr' or an S3 input.") if args.s3_input is not None: - bucket_name, service_endpoint, credentials = upload_to_s3.check_s3_credentials(args.s3_bucket_name, args.s3_service_endpoint, args.s3_credentials) + bucket_name, service_endpoint, credentials = s3_utils.check_s3_credentials(args.s3_bucket_name, args.s3_service_endpoint, args.s3_credentials) - s3_path, fs = upload_to_s3.get_s3_path(args.s3_input, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials) + s3_path, fs = s3_utils.get_s3_path(args.s3_input, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials) with zarr.open(s3_path, mode="r") as f: dataset = f[args.input_key] diff --git a/scripts/prediction/postprocess_seg.py b/scripts/prediction/postprocess_seg.py index bc186cf..ea9f73d 100644 --- a/scripts/prediction/postprocess_seg.py +++ b/scripts/prediction/postprocess_seg.py @@ -5,10 +5,9 @@ import pandas as pd import zarr -import upload_to_s3 - sys.path.append("../..") +import flamingo_tools.s3_utils as s3_utils def main(): from flamingo_tools.segmentation import filter_isolated_objects @@ -38,14 +37,14 @@ def main(): tsv_table=None if args.s3_input is not None: - bucket_name, service_endpoint, credentials = upload_to_s3.check_s3_credentials(args.s3_bucket_name, args.s3_service_endpoint, args.s3_credentials) + bucket_name, service_endpoint, credentials = s3_utils.check_s3_credentials(args.s3_bucket_name, args.s3_service_endpoint, args.s3_credentials) - s3_path, fs = upload_to_s3.get_s3_path(args.s3_input, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials) + s3_path, fs = s3_utils.get_s3_path(args.s3_input, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials) with zarr.open(s3_path, mode="r") as f: segmentation = f[args.input_key] if args.tsv is not None: - tsv_path, fs = upload_to_s3.get_s3_path(args.tsv, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials) + tsv_path, fs = s3_utils.get_s3_path(args.tsv, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials) with fs.open(tsv_path, 'r') as f: tsv_table = pd.read_csv(f, sep="\t") diff --git 
a/scripts/prediction/upload_to_s3.py b/scripts/prediction/upload_to_s3.py index 96b75c2..d557381 100644 --- a/scripts/prediction/upload_to_s3.py +++ b/scripts/prediction/upload_to_s3.py @@ -6,118 +6,21 @@ from mobie.metadata import add_remote_project_metadata from tqdm import tqdm +import flamingo_tools.s3_utils as s3_utils + # Using incucyte s3 as a temporary measure. MOBIE_FOLDER = "/mnt/lustre-emmy-hdd/projects/nim00007/data/moser/lightsheet/mobie" SERVICE_ENDPOINT = "https://s3.gwdg.de/" BUCKET_NAME = "incucyte-general/lightsheet" -# For MoBIE: -# https://s3.gwdg.de/incucyte-general/lightsheet - -def check_s3_credentials(bucket_name, service_endpoint, credentials): - """ - Check if S3 parameter and credentials were set either as a function input or were exported as environment variables. - """ - if bucket_name is None: - bucket_name = os.getenv('BUCKET_NAME') - if bucket_name is None: - raise ValueError("Provide a bucket name for accessing S3 data.\nEither by using an optional argument or exporting an environment variable:\n--s3_bucket_name \nexport BUCKET_NAME=") - - if service_endpoint is None: - service_endpoint = os.getenv('SERVICE_ENDPOINT') - if service_endpoint is None: - raise ValueError("Provide a service endpoint for accessing S3 data.\nEither by using an optional argument or exporting an environment variable:\n--s3_service_endpoint \nexport SERVICE_ENDPOINT=") - - if credentials is None: - access_key = os.getenv('AWS_ACCESS_KEY_ID') - secret_key = os.getenv('AWS_SECRET_ACCESS_KEY') - if access_key is None: - raise ValueError("Either provide a credential file as an optional argument or export an access key as an environment variable:\nexport AWS_ACCESS_KEY_ID=") - if secret_key is None: - raise ValueError("Either provide a credential file as an optional argument or export a secret access key as an environment variable:\nexport AWS_SECRET_ACCESS_KEY=") - - return bucket_name, service_endpoint, credentials - - -def get_s3_path( - input_path, - bucket_name, service_endpoint, - credential_file=None, -): - """ - Get S3 path for a file or folder and file system based on S3 parameters and credentials. - """ - fs = create_s3_target(url=service_endpoint, anon=False, credential_file=credential_file) - - zarr_path=f"{bucket_name}/{input_path}" - - if not fs.exists(zarr_path): - print(f"Error: S3 path {zarr_path} does not exist!") - - s3_path = zarr.storage.FSStore(zarr_path, fs=fs) - - return s3_path, fs - - -def read_s3_credentials(credential_file): - key, secret = None, None - with open(credential_file) as f: - for line in f: - if line.startswith("aws_access_key_id"): - key = line.rstrip("\n").strip().split(" ")[-1] - if line.startswith("aws_secret_access_key"): - secret = line.rstrip("\n").strip().split(" ")[-1] - if key is None or secret is None: - raise ValueError(f"Invalid credential file {credential_file}") - return key, secret - - -def create_s3_target(url, anon=False, credential_file=None): - """ - Create file system for S3 bucket based on a service endpoint and an optional credential file. - If the credential file is not provided, the s3fs.S3FileSystem function checks the environment variables - AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY. 
- """ - client_kwargs = {"endpoint_url": url} - if credential_file is not None: - key, secret = read_s3_credentials(credential_file) - fs = s3fs.S3FileSystem(key=key, secret=secret, client_kwargs=client_kwargs) - else: - fs = s3fs.S3FileSystem(anon=anon, client_kwargs=client_kwargs) - return fs - - -def remote_metadata(): - add_remote_project_metadata(MOBIE_FOLDER, BUCKET_NAME, SERVICE_ENDPOINT) - - -def upload_data(): - target = create_s3_target( - SERVICE_ENDPOINT, - credential_file="./credentials.incucyte" - ) - to_upload = [] - for root, dirs, files in os.walk(MOBIE_FOLDER): - dirs.sort() - for ff in files: - if ff.endswith(".xml"): - to_upload.append(os.path.join(root, ff)) - - print("Uploading", len(to_upload), "files to") - - for path in tqdm(to_upload): - rel_path = os.path.relpath(path, MOBIE_FOLDER) - target.put( - path, os.path.join(BUCKET_NAME, rel_path) - ) - +# FIXME: Complete overhaul with flexible folder, service endpoint, bucket name # FIXME: access via s3 is not working due to permission issues. # Maybe this is not working due to bdv fileformat?! # Make an issue in MoBIE. def main(): # remote_metadata() - upload_data() + s3_utils.upload_data() if __name__ == "__main__": From bb41c548ece4e4494eecba9b6537c052d4ae2acf Mon Sep 17 00:00:00 2001 From: Martin Schilling Date: Thu, 10 Apr 2025 18:00:39 +0200 Subject: [PATCH 4/7] Expand segmentation table with distance to nearest neighbors --- flamingo_tools/segmentation/postprocessing.py | 27 +++++++++ scripts/prediction/expand_seg_table.py | 60 +++++++++++++++++++ 2 files changed, 87 insertions(+) create mode 100644 scripts/prediction/expand_seg_table.py diff --git a/flamingo_tools/segmentation/postprocessing.py b/flamingo_tools/segmentation/postprocessing.py index 46b6bcd..1264956 100644 --- a/flamingo_tools/segmentation/postprocessing.py +++ b/flamingo_tools/segmentation/postprocessing.py @@ -7,11 +7,38 @@ from scipy.spatial import distance from scipy.sparse import csr_matrix from tqdm import tqdm +from sklearn.neighbors import NearestNeighbors import elf.parallel as parallel from elf.io import open_file import nifty.tools as nt +def distance_nearest_neighbors(tsv_table, n_neighbors=10, expand_table=True): + """ + Calculate average distance of n nearest neighbors. 
+ + :param DataFrame tsv_table: + :param int n_neighbors: Number of nearest neighbors + :param bool expand_table: Flag for expanding DataFrame + :returns: List of average distances + :rtype: list + """ + centroids = list(zip(tsv_table["anchor_x"], tsv_table["anchor_y"], tsv_table["anchor_z"])) + + coordinates = np.array(centroids) + + # nearest neighbor is always itself, so n_neighbors+=1 + nbrs = NearestNeighbors(n_neighbors=n_neighbors+1).fit(coordinates) + distances, indices = nbrs.kneighbors(coordinates) + + # Average distance to nearest neighbors + distance_avg = [sum(d) / len(d) for d in distances[:, 1:]] + + if expand_table: + tsv_table['distance_nn'+str(n_neighbors)] = distance_avg + + return distance_avg + def filter_isolated_objects( segmentation, output_path, tsv_table=None, distance_threshold=15, neighbor_threshold=5, min_size=1000, diff --git a/scripts/prediction/expand_seg_table.py b/scripts/prediction/expand_seg_table.py new file mode 100644 index 0000000..94c15ff --- /dev/null +++ b/scripts/prediction/expand_seg_table.py @@ -0,0 +1,60 @@ +import argparse + +import pandas as pd + +import flamingo_tools.segmentation.postprocessing as postprocessing +import flamingo_tools.s3_utils as s3_utils + +def main( + in_path, out_path, n_neighbors=None, + s3=False, s3_credentials=None, s3_bucket_name=None, s3_service_endpoint=None, + ): + """ + + :param str input_file: Path to table in TSV format + :param str out_path: Path to save output + :param bool s3: Flag for using an S3 bucket + :param str s3_credentials: Path to file containing S3 credentials + :param str s3_bucket_name: S3 bucket name. Optional if BUCKET_NAME has been exported + :param str s3_service_endpoint: S3 service endpoint. Optional if SERVICE_ENDPOINT has been exported + """ + if s3: + bucket_name, service_endpoint, credentials = s3_utils.check_s3_credentials(s3_bucket_name, s3_service_endpoint, s3_credentials) + tsv_path, fs = s3_utils.get_s3_path(in_path, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials) + with fs.open(tsv_path, 'r') as f: + tsv_table = pd.read_csv(f, sep="\t") + else: + with open(in_path, 'r') as f: + tsv_table = pd.read_csv(f, sep="\t") + + if n_neighbors is not None: + nn_list = [int(n) for n in n_neighbors.split(",")] + for n_neighbor in nn_list: + if n_neighbor >= len(tsv_table): + raise ValueError(f"Number of neighbors: {n_neighbor} exceeds number of elements in dataframe: {len(tsv_table)}.") + + _ = postprocessing.distance_nearest_neighbors(tsv_table=tsv_table, n_neighbors=n_neighbor, expand_table=True) + + tsv_table.to_csv(out_path, sep="\t") + +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description="Script for expanding the segmentation table of MoBIE with additonal parameters. Either locally or on an S3 bucket.") + + parser.add_argument("-i", "--input", required=True) + parser.add_argument("-o", "--output", required=True) + + parser.add_argument("--n_neighbors", default=None, help="Value(s) for number of nearest neighbors in format 'n1,n2,...,nx'. New columns contain the average distance to nearest neighbors.") + + parser.add_argument("--s3", action="store_true", help="Flag for using S3 bucket") + parser.add_argument("--s3_credentials", default=None, help="Input file containing S3 credentials. Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported") + parser.add_argument("--s3_bucket_name", default=None, help="S3 bucket name. 
Optional if BUCKET_NAME was exported") + parser.add_argument("--s3_service_endpoint", default=None, help="S3 service endpoint. Optional if SERVICE_ENDPOINT was exported") + + args = parser.parse_args() + + main( + args.input, args.output, args.n_neighbors, + args.s3, args.s3_credentials, args.s3_bucket_name, args.s3_service_endpoint, + ) From dad0b3127db55a653843ecfacdbb7e51ea7ddc0d Mon Sep 17 00:00:00 2001 From: Martin Schilling Date: Thu, 10 Apr 2025 18:11:31 +0200 Subject: [PATCH 5/7] Removed mobie dependencies from package --- flamingo_tools/s3_utils.py | 5 ----- scripts/prediction/upload_to_s3.py | 4 ++++ 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/flamingo_tools/s3_utils.py b/flamingo_tools/s3_utils.py index 9948eb0..09e08f2 100644 --- a/flamingo_tools/s3_utils.py +++ b/flamingo_tools/s3_utils.py @@ -3,7 +3,6 @@ import s3fs import zarr -from mobie.metadata import add_remote_project_metadata from tqdm import tqdm # Using incucyte s3 as a temporary measure. @@ -87,10 +86,6 @@ def create_s3_target(url, anon=False, credential_file=None): return fs -def remote_metadata(): - add_remote_project_metadata(MOBIE_FOLDER, BUCKET_NAME, SERVICE_ENDPOINT) - - def upload_data(): target = create_s3_target( SERVICE_ENDPOINT, diff --git a/scripts/prediction/upload_to_s3.py b/scripts/prediction/upload_to_s3.py index d557381..2f6eb44 100644 --- a/scripts/prediction/upload_to_s3.py +++ b/scripts/prediction/upload_to_s3.py @@ -23,5 +23,9 @@ def main(): s3_utils.upload_data() +def remote_metadata(): + add_remote_project_metadata(MOBIE_FOLDER, BUCKET_NAME, SERVICE_ENDPOINT) + + if __name__ == "__main__": main() From 43941341edd42519dfdca20a60850385eb38a3d5 Mon Sep 17 00:00:00 2001 From: Martin Schilling Date: Thu, 10 Apr 2025 20:45:51 +0200 Subject: [PATCH 6/7] Fixed issue for setting chunks --- flamingo_tools/segmentation/unet_prediction.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/flamingo_tools/segmentation/unet_prediction.py b/flamingo_tools/segmentation/unet_prediction.py index 014d432..33c4217 100644 --- a/flamingo_tools/segmentation/unet_prediction.py +++ b/flamingo_tools/segmentation/unet_prediction.py @@ -60,11 +60,14 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo if input_key is None: input_ = imageio.imread(input_path) + chunks = (64, 64, 64) elif s3 is not None: with zarr.open(input_path, mode="r") as f: input_ = f[input_key] + chunks = input_.chunks() else: input_ = open_file(input_path, "r")[input_key] + chunks = (64, 64, 64) if scale is None or scale == 1: original_shape = None @@ -95,7 +98,7 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo # Compute the global mean and standard deviation. 
n_threads = min(16, mp.cpu_count()) mean, std = parallel.mean_and_std( - input_, block_shape=tuple([2* i for i in input_.chunks]), n_threads=n_threads, verbose=True, + input_, block_shape=tuple([2* i for i in chunks]), n_threads=n_threads, verbose=True, mask=image_mask ) print("Mean and standard deviation computed for the full volume:") @@ -163,7 +166,7 @@ def find_mask(input_path, input_key, output_folder, s3=None): else: fin = open_file(input_path, "r") raw = fin[input_key] - chunks = raw.chunks + chunks = (64, 64, 64) block_shape = tuple(2 * ch for ch in chunks) blocking = nt.blocking([0, 0, 0], raw.shape, block_shape) From 6bf43e3659f33d8f8f5924eea8c4db384354621d Mon Sep 17 00:00:00 2001 From: Martin Schilling Date: Fri, 11 Apr 2025 15:12:46 +0200 Subject: [PATCH 7/7] Improved S3 functions, fixed issues --- flamingo_tools/s3_utils.py | 70 ++++++++++--------- .../segmentation/unet_prediction.py | 45 ++++++------ scripts/extract_block.py | 4 +- scripts/prediction/count_cells.py | 4 +- scripts/prediction/expand_seg_table.py | 3 +- scripts/prediction/postprocess_seg.py | 6 +- 6 files changed, 66 insertions(+), 66 deletions(-) diff --git a/flamingo_tools/s3_utils.py b/flamingo_tools/s3_utils.py index 09e08f2..c371af4 100644 --- a/flamingo_tools/s3_utils.py +++ b/flamingo_tools/s3_utils.py @@ -3,49 +3,74 @@ import s3fs import zarr -from tqdm import tqdm +""" +This script contains utility functions for processing data located on an S3 storage. +The upload of data to the storage system should be performed with 'rclone'. +""" -# Using incucyte s3 as a temporary measure. -MOBIE_FOLDER = "/mnt/lustre-emmy-hdd/projects/nim00007/data/moser/lightsheet/mobie" +# Dedicated bucket for cochlea lightsheet project +MOBIE_FOLDER = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet" SERVICE_ENDPOINT = "https://s3.gwdg.de/" -BUCKET_NAME = "incucyte-general/lightsheet" +BUCKET_NAME = "cochlea-lightsheet" + +DEFAULT_CREDENTIALS = os.path.expanduser("~/.aws/credentials") # For MoBIE: # https://s3.gwdg.de/incucyte-general/lightsheet -def check_s3_credentials(bucket_name, service_endpoint, credentials): +def check_s3_credentials(bucket_name, service_endpoint, credential_file): """ Check if S3 parameter and credentials were set either as a function input or were exported as environment variables. 
""" if bucket_name is None: bucket_name = os.getenv('BUCKET_NAME') if bucket_name is None: - raise ValueError("Provide a bucket name for accessing S3 data.\nEither by using an optional argument or exporting an environment variable:\n--s3_bucket_name \nexport BUCKET_NAME=") + if BUCKET_NAME in globals(): + bucket_name = BUCKET_NAME + else: + raise ValueError("Provide a bucket name for accessing S3 data.\nEither by using an optional argument or exporting an environment variable:\n--s3_bucket_name \nexport BUCKET_NAME=") if service_endpoint is None: service_endpoint = os.getenv('SERVICE_ENDPOINT') if service_endpoint is None: - raise ValueError("Provide a service endpoint for accessing S3 data.\nEither by using an optional argument or exporting an environment variable:\n--s3_service_endpoint \nexport SERVICE_ENDPOINT=") + if SERVICE_ENDPOINT in globals(): + service_endpoint = SERVICE_ENDPOINT + else: + raise ValueError("Provide a service endpoint for accessing S3 data.\nEither by using an optional argument or exporting an environment variable:\n--s3_service_endpoint \nexport SERVICE_ENDPOINT=") - if credentials is None: + if credential_file is None: access_key = os.getenv('AWS_ACCESS_KEY_ID') secret_key = os.getenv('AWS_SECRET_ACCESS_KEY') + + # check for default credentials if no credential_file is provided if access_key is None: - raise ValueError("Either provide a credential file as an optional argument or export an access key as an environment variable:\nexport AWS_ACCESS_KEY_ID=") + if os.path.isfile(DEFAULT_CREDENTIALS): + access_key, _ = read_s3_credentials(credential_file=DEFAULT_CREDENTIALS) + else: + raise ValueError(f"Either provide a credential file as an optional argument, have credentials at '{DEFAULT_CREDENTIALS}', or export an access key as an environment variable:\nexport AWS_ACCESS_KEY_ID=") if secret_key is None: - raise ValueError("Either provide a credential file as an optional argument or export a secret access key as an environment variable:\nexport AWS_SECRET_ACCESS_KEY=") + # check for default credentials + if os.path.isfile(DEFAULT_CREDENTIALS): + _, secret_key = read_s3_credentials(credential_file=DEFAULT_CREDENTIALS) + else: + raise ValueError(f"Either provide a credential file as an optional argument, have credentials at '{DEFAULT_CREDENTIALS}', or export a secret access key as an environment variable:\nexport AWS_SECRET_ACCESS_KEY=") - return bucket_name, service_endpoint, credentials + else: + # check validity of credential file + _, _ = read_s3_credentials(credential_file=credential_file) + return bucket_name, service_endpoint, credential_file def get_s3_path( input_path, - bucket_name, service_endpoint, + bucket_name=None, service_endpoint=None, credential_file=None, ): """ Get S3 path for a file or folder and file system based on S3 parameters and credentials. 
""" + bucket_name, service_endpoint, credential_file = check_s3_credentials(bucket_name, service_endpoint, credential_file) + fs = create_s3_target(url=service_endpoint, anon=False, credential_file=credential_file) zarr_path=f"{bucket_name}/{input_path}" @@ -84,24 +109,3 @@ def create_s3_target(url, anon=False, credential_file=None): else: fs = s3fs.S3FileSystem(anon=anon, client_kwargs=client_kwargs) return fs - - -def upload_data(): - target = create_s3_target( - SERVICE_ENDPOINT, - credential_file="./credentials.incucyte" - ) - to_upload = [] - for root, dirs, files in os.walk(MOBIE_FOLDER): - dirs.sort() - for ff in files: - if ff.endswith(".xml"): - to_upload.append(os.path.join(root, ff)) - - print("Uploading", len(to_upload), "files to") - - for path in tqdm(to_upload): - rel_path = os.path.relpath(path, MOBIE_FOLDER) - target.put( - path, os.path.join(BUCKET_NAME, rel_path) - ) \ No newline at end of file diff --git a/flamingo_tools/segmentation/unet_prediction.py b/flamingo_tools/segmentation/unet_prediction.py index 33c4217..35298bb 100644 --- a/flamingo_tools/segmentation/unet_prediction.py +++ b/flamingo_tools/segmentation/unet_prediction.py @@ -12,6 +12,7 @@ import torch import z5py import zarr +import tifffile import json from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper @@ -59,15 +60,18 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo image_mask = z5py.File(mask_path, "r")["mask"] if input_key is None: - input_ = imageio.imread(input_path) - chunks = (64, 64, 64) - elif s3 is not None: + try: + input_ = tifffile.memmap(input_path, mode="r") + except ValueError: + print(f"Could not memmap the data from {input_path}. Fall back to load it into memory.") + input_ = imageio.imread(input_path) + elif isinstance(input_path, str): + input_ = open_file(input_path, "r")[input_key] + else: with zarr.open(input_path, mode="r") as f: input_ = f[input_key] - chunks = input_.chunks() - else: - input_ = open_file(input_path, "r")[input_key] - chunks = (64, 64, 64) + + chunks = getattr(input_, "chunks", (64,64,64)) if scale is None or scale == 1: original_shape = None @@ -157,16 +161,19 @@ def find_mask(input_path, input_key, output_folder, s3=None): return if input_key is None: - raw = imageio.imread(input_path) - chunks = (64, 64, 64) - elif s3 is not None: - with zarr.open(input_path, mode="r") as fin: - raw = fin[input_key] - chunks = raw.chunks - else: + try: + raw = tifffile.memmap(input_path, mode="r") + except ValueError: + print(f"Could not memmap the data from {input_path}. Fall back to load it into memory.") + raw = imageio.imread(input_path) + elif isinstance(input_path, str): fin = open_file(input_path, "r") raw = fin[input_key] - chunks = (64, 64, 64) + else: + with zarr.open(input_path, mode="r") as fin: + raw = fin[input_key] + + chunks = getattr(raw, "chunks", (64,64,64)) block_shape = tuple(2 * ch for ch in chunks) blocking = nt.blocking([0, 0, 0], raw.shape, block_shape) @@ -318,9 +325,7 @@ def run_unet_prediction_preprocess_slurm( and stored in a JSON file within the output folder as mean_std.json. 
""" if s3 is not None: - bucket_name, service_endpoint, credentials = s3_utils.check_s3_credentials(s3_bucket_name, s3_service_endpoint, s3_credentials) - - input_path, fs = s3_utils.get_s3_path(input_path, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials) + input_path, fs = s3_utils.get_s3_path(input_path, bucket_name=s3_bucket_name, service_endpoint=s3_service_endpoint, credential_file=s3_credentials) if not os.path.isdir(os.path.join(output_folder, "mask.zarr")): find_mask(input_path, input_key, output_folder, s3=s3) @@ -355,9 +360,7 @@ def run_unet_prediction_slurm( slurm_task_id = os.environ.get("SLURM_ARRAY_TASK_ID") if s3 is not None: - bucket_name, service_endpoint, credentials = s3_utils.check_s3_credentials(s3_bucket_name, s3_service_endpoint, s3_credentials) - - input_path, fs = s3_utils.get_s3_path(input_path, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials) + input_path, fs = s3_utils.get_s3_path(input_path, bucket_name=s3_bucket_name, service_endpoint=s3_service_endpoint, credential_file=s3_credentials) if slurm_task_id is not None: slurm_task_id = int(slurm_task_id) diff --git a/scripts/extract_block.py b/scripts/extract_block.py index d8fc376..3f5ab3a 100644 --- a/scripts/extract_block.py +++ b/scripts/extract_block.py @@ -66,9 +66,7 @@ def main( roi = tuple(slice(co - rh, co + rh) for co, rh in zip(coords, roi_halo)) if s3: - bucket_name, service_endpoint, credentials = s3_utils.check_s3_credentials(s3_bucket_name, s3_service_endpoint, s3_credentials) - - s3_path, fs = s3_utils.get_s3_path(input_file, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials) + s3_path, fs = s3_utils.get_s3_path(input_file, bucket_name=s3_bucket_name, service_endpoint=s3_service_endpoint, credential_file=s3_credentials) with zarr.open(s3_path, mode="r") as f: raw = f[input_key][roi] diff --git a/scripts/prediction/count_cells.py b/scripts/prediction/count_cells.py index 8bfef8d..087dd79 100644 --- a/scripts/prediction/count_cells.py +++ b/scripts/prediction/count_cells.py @@ -28,9 +28,7 @@ def main(): raise ValueError("Either provide an output_folder containing 'segmentation.zarr' or an S3 input.") if args.s3_input is not None: - bucket_name, service_endpoint, credentials = s3_utils.check_s3_credentials(args.s3_bucket_name, args.s3_service_endpoint, args.s3_credentials) - - s3_path, fs = s3_utils.get_s3_path(args.s3_input, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials) + s3_path, fs = s3_utils.get_s3_path(args.s3_input, bucket_name=args.s3_bucket_name, service_endpoint=args.s3_service_endpoint, credential_file=args.s3_credentials) with zarr.open(s3_path, mode="r") as f: dataset = f[args.input_key] diff --git a/scripts/prediction/expand_seg_table.py b/scripts/prediction/expand_seg_table.py index 94c15ff..dc080fd 100644 --- a/scripts/prediction/expand_seg_table.py +++ b/scripts/prediction/expand_seg_table.py @@ -19,8 +19,7 @@ def main( :param str s3_service_endpoint: S3 service endpoint. 
Optional if SERVICE_ENDPOINT has been exported """ if s3: - bucket_name, service_endpoint, credentials = s3_utils.check_s3_credentials(s3_bucket_name, s3_service_endpoint, s3_credentials) - tsv_path, fs = s3_utils.get_s3_path(in_path, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials) + tsv_path, fs = s3_utils.get_s3_path(in_path, bucket_name=s3_bucket_name, service_endpoint=s3_service_endpoint, credential_file=s3_credentials) with fs.open(tsv_path, 'r') as f: tsv_table = pd.read_csv(f, sep="\t") else: diff --git a/scripts/prediction/postprocess_seg.py b/scripts/prediction/postprocess_seg.py index ea9f73d..04c05da 100644 --- a/scripts/prediction/postprocess_seg.py +++ b/scripts/prediction/postprocess_seg.py @@ -37,14 +37,12 @@ def main(): tsv_table=None if args.s3_input is not None: - bucket_name, service_endpoint, credentials = s3_utils.check_s3_credentials(args.s3_bucket_name, args.s3_service_endpoint, args.s3_credentials) - - s3_path, fs = s3_utils.get_s3_path(args.s3_input, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials) + s3_path, fs = s3_utils.get_s3_path(args.s3_input, bucket_name=args.s3_bucket_name, service_endpoint=args.s3_service_endpoint, credential_file=args.s3_credentials) with zarr.open(s3_path, mode="r") as f: segmentation = f[args.input_key] if args.tsv is not None: - tsv_path, fs = s3_utils.get_s3_path(args.tsv, bucket_name=bucket_name, service_endpoint=service_endpoint, credential_file=credentials) + tsv_path, fs = s3_utils.get_s3_path(args.tsv, bucket_name=args.s3_bucket_name, service_endpoint=args.s3_service_endpoint, credential_file=args.s3_credentials) with fs.open(tsv_path, 'r') as f: tsv_table = pd.read_csv(f, sep="\t")
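
A minimal usage sketch of the S3 helpers this series settles on (flamingo_tools.s3_utils.get_s3_path combined with zarr), for reviewers trying the branch locally. The object path and the dataset key below are placeholders; the bucket name and endpoint are the defaults hard-coded in flamingo_tools/s3_utils.py, and credentials are assumed to be exported as AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY or present in ~/.aws/credentials.

    import zarr

    import flamingo_tools.s3_utils as s3_utils

    # Resolve bucket name, endpoint and credentials (falls back to the exported
    # BUCKET_NAME / SERVICE_ENDPOINT variables or the module defaults) and build
    # an FSStore pointing at the remote zarr dataset.
    store, fs = s3_utils.get_s3_path(
        "MyCochlea/PV.ome.zarr",              # placeholder object path inside the bucket
        bucket_name="cochlea-lightsheet",
        service_endpoint="https://s3.gwdg.de/",
        credential_file=None,                 # use exported AWS_* variables instead
    )

    # Open the store read-only and inspect one dataset key.
    f = zarr.open(store, mode="r")
    data = f["s0"]                            # placeholder key, e.g. a multiscale level
    print(data.shape, data.chunks)

The returned filesystem handle (fs) is what the scripts use for non-zarr files on the same bucket, e.g. fs.open(tsv_path, "r") when reading a MoBIE segmentation table in postprocess_seg.py or expand_seg_table.py.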