diff --git a/scripts/export_lower_resolution.py b/scripts/export_lower_resolution.py index f60de4b..5e2d334 100644 --- a/scripts/export_lower_resolution.py +++ b/scripts/export_lower_resolution.py @@ -1,5 +1,6 @@ import argparse import os +import warnings import numpy as np import pandas as pd @@ -7,7 +8,6 @@ import zarr from flamingo_tools.s3_utils import get_s3_path, BUCKET_NAME, SERVICE_ENDPOINT -# from skimage.segmentation import relabel_sequential def filter_component(fs, segmentation, cochlea, seg_name, components): @@ -19,10 +19,12 @@ def filter_component(fs, segmentation, cochlea, seg_name, components): # Then we get the ids for the components and us them to filter the segmentation. component_mask = np.isin(table.component_labels.values, components) keep_label_ids = table.label_id.values[component_mask].astype("int64") + if max(keep_label_ids) > np.iinfo("uint16").max: + warnings.warn(f"Label ID exceeds maximum of data type 'uint16': {np.iinfo('uint16').max}.") + filter_mask = ~np.isin(segmentation, keep_label_ids) segmentation[filter_mask] = 0 - - # segmentation, _, _ = relabel_sequential(segmentation) + segmentation = segmentation.astype("uint16") return segmentation @@ -41,7 +43,7 @@ def export_lower_resolution(args): s3_store, fs = get_s3_path(internal_path, bucket_name=BUCKET_NAME, service_endpoint=SERVICE_ENDPOINT) with zarr.open(s3_store, mode="r") as f: data = f[input_key][:] - print(data.shape) + if args.filter_by_components is not None: data = filter_component(fs, data, args.cochlea, channel, args.filter_by_components) if args.binarize: diff --git a/scripts/export_synapse_detections.py b/scripts/export_synapse_detections.py index 52fee04..a555577 100644 --- a/scripts/export_synapse_detections.py +++ b/scripts/export_synapse_detections.py @@ -12,7 +12,7 @@ from tqdm import tqdm -def export_synapse_detections(cochlea, scale, output_folder, synapse_name, reference_ihcs, max_dist, radius): +def export_synapse_detections(cochlea, scale, output_folder, synapse_name, reference_ihcs, max_dist, radius, id_offset): s3 = create_s3_target() content = s3.open(f"{BUCKET_NAME}/{cochlea}/dataset.json", mode="r", encoding="utf-8") @@ -53,14 +53,18 @@ def export_synapse_detections(cochlea, scale, output_folder, synapse_name, refer coordinates /= (2 ** scale) coordinates = np.round(coordinates, 0).astype("int") + ihc_ids = syn_table["matched_ihc"].values + # Create the output. output = np.zeros(shape, dtype="uint16") mask = ball(radius).astype(bool) - for coord in tqdm(coordinates, desc="Writing synapses to volume"): + for coord, matched_ihc in tqdm( + zip(coordinates, ihc_ids), total=len(coordinates), desc="Writing synapses to volume" + ): bb = tuple(slice(c - radius, c + radius + 1) for c in coord) try: - output[bb][mask] = 1 + output[bb][mask] = matched_ihc + id_offset except IndexError: print("Index error for", coord) continue @@ -68,7 +72,10 @@ def export_synapse_detections(cochlea, scale, output_folder, synapse_name, refer # Write the output. out_folder = os.path.join(output_folder, cochlea, f"scale{scale}") os.makedirs(out_folder, exist_ok=True) - out_path = os.path.join(out_folder, f"{synapse_name}.tif") + if id_offset != 0: + out_path = os.path.join(out_folder, f"{synapse_name}_offset{id_offset}.tif") + else: + out_path = os.path.join(out_folder, f"{synapse_name}.tif") print("Writing synapses to", out_path) tifffile.imwrite(out_path, output, bigtiff=True, compression="zlib") @@ -78,16 +85,18 @@ def main(): parser.add_argument("--cochlea", "-c", required=True) parser.add_argument("--scale", "-s", type=int, required=True) parser.add_argument("--output_folder", "-o", required=True) - parser.add_argument("--synapse_name", default="synapse_v3_ihc_v4") - parser.add_argument("--reference_ihcs", default="IHC_v4") + parser.add_argument("--synapse_name", default="synapse_v3_ihc_v4b") + parser.add_argument("--reference_ihcs", default="IHC_v4b") parser.add_argument("--max_dist", type=float, default=3.0) parser.add_argument("--radius", type=int, default=3) + parser.add_argument("--id_offset", type=int, default=0) args = parser.parse_args() export_synapse_detections( args.cochlea, args.scale, args.output_folder, args.synapse_name, args.reference_ihcs, - args.max_dist, args.radius + args.max_dist, args.radius, + args.id_offset, )