diff --git a/flamingo_tools/extract_block_util.py b/flamingo_tools/extract_block_util.py index 09851da..73fb775 100644 --- a/flamingo_tools/extract_block_util.py +++ b/flamingo_tools/extract_block_util.py @@ -1,9 +1,10 @@ import os -from typing import Optional, List +from typing import Optional, List, Union, Tuple import imageio.v3 as imageio import numpy as np import zarr +from skimage.transform import rescale import flamingo_tools.s3_utils as s3_utils from flamingo_tools.file_utils import read_image_data @@ -15,14 +16,15 @@ def extract_block( output_dir: Optional[str] = None, input_key: Optional[str] = None, output_key: Optional[str] = None, - resolution: float = 0.38, + resolution: Union[float, Tuple[float, float, float]] = 0.38, roi_halo: List[int] = [128, 128, 64], tif: bool = False, s3: Optional[bool] = False, s3_credentials: Optional[str] = None, s3_bucket_name: Optional[str] = None, s3_service_endpoint: Optional[str] = None, -): + scale_factor: Optional[Tuple[float, float, float]] = None, +) -> None: """Extract block around coordinate from input data according to a given halo. Either from a local file or from an S3 bucket. @@ -38,11 +40,14 @@ def extract_block( s3_bucket_name: S3 bucket name. s3_service_endpoint: S3 service endpoint. s3_credentials: File path to credentials for S3 bucket. + scale_factor: Optional factor for rescaling the extracted data. """ - coords = [int(round(c)) for c in coords] - coord_string = "-".join([str(c).zfill(4) for c in coords]) + coord_string = "-".join([str(int(round(c))).zfill(4) for c in coords]) # Dimensions are inversed to view in MoBIE (x y z) -> (z y x) + # Make sure the coords / roi_halo are not modified in-place. + coords = coords.copy() + roi_halo = roi_halo.copy() coords.reverse() roi_halo.reverse() @@ -61,6 +66,7 @@ def extract_block( if output_dir == "": output_dir = input_dir + os.makedirs(output_dir, exist_ok=True) if tif: if output_key is None: @@ -73,21 +79,32 @@ def extract_block( output_key = "raw" if output_key is None else output_key output_file = os.path.join(output_dir, basename + "_crop_" + coord_string + ".n5") - coords = np.array(coords) + coords = np.array(coords).astype("float") + if not isinstance(resolution, float): + assert len(resolution) == 3 + resolution = np.array(resolution)[::-1] coords = coords / resolution coords = np.round(coords).astype(np.int32) roi = tuple(slice(co - rh, co + rh) for co, rh in zip(coords, roi_halo)) if s3: - input_path, fs = s3_utils.get_s3_path(input_path, bucket_name=s3_bucket_name, - service_endpoint=s3_service_endpoint, credential_file=s3_credentials) + input_path, fs = s3_utils.get_s3_path( + input_path, bucket_name=s3_bucket_name, + service_endpoint=s3_service_endpoint, credential_file=s3_credentials + ) data_ = read_image_data(input_path, input_key) data_roi = data_[roi] + if scale_factor is not None: + kwargs = {"preserve_range": True} + # Check if this is a segmentation. 
+ if data_roi.dtype in (np.dtype("int32"), np.dtype("uint32"), np.dtype("int64"), np.dtype("uint64")): + kwargs.update({"order": 0, "anti_aliasing": False}) + data_roi = rescale(data_roi, scale_factor, **kwargs).astype(data_roi.dtype) if tif: imageio.imwrite(output_file, data_roi, compression="zlib") else: - with zarr.open(output_file, mode="w") as f_out: - f_out.create_dataset(output_key, data=data_roi, compression="gzip") + f_out = zarr.open(output_file, mode="w") + f_out.create_dataset(output_key, data=data_roi, compression="gzip") diff --git a/flamingo_tools/file_utils.py b/flamingo_tools/file_utils.py index acfcdca..7015aa9 100644 --- a/flamingo_tools/file_utils.py +++ b/flamingo_tools/file_utils.py @@ -85,6 +85,6 @@ def read_image_data(input_path: Union[str, Store], input_key: Optional[str]) -> elif isinstance(input_path, str): input_ = open_file(input_path, "r")[input_key] else: - with zarr.open(input_path, mode="r") as f: - input_ = f[input_key] + f = zarr.open(input_path, mode="r") + input_ = f[input_key] return input_ diff --git a/flamingo_tools/s3_utils.py b/flamingo_tools/s3_utils.py index ec5c111..a3c59ec 100644 --- a/flamingo_tools/s3_utils.py +++ b/flamingo_tools/s3_utils.py @@ -1,7 +1,11 @@ """This file contains utility functions for processing data located on an S3 storage. The upload of data to the storage system should be performed with 'rclone'. """ +import json import os +import warnings +from shutil import which +from subprocess import run from typing import Optional, Tuple import s3fs @@ -115,14 +119,23 @@ def get_s3_path( bucket_name, service_endpoint, credential_file ) - s3_filesystem = create_s3_target(url=service_endpoint, anon=False, credential_file=credential_file) + zarr_major_version = int(zarr.__version__.split(".")[0]) + s3_filesystem = create_s3_target( + url=service_endpoint, anon=False, credential_file=credential_file, asynchronous=zarr_major_version == 3, + ) zarr_path = f"{bucket_name}/{input_path}" - if not s3_filesystem.exists(zarr_path): + if zarr_major_version == 2 and not s3_filesystem.exists(zarr_path): print(f"Error: S3 path {zarr_path} does not exist!") - s3_path = zarr.storage.FSStore(zarr_path, fs=s3_filesystem) + # The approach for opening a dataset from S3 differs in zarr v2 and zarr v3. + if zarr_major_version == 2: + s3_path = zarr.storage.FSStore(zarr_path, fs=s3_filesystem) + elif zarr_major_version == 3: + s3_path = zarr.storage.FsspecStore(fs=s3_filesystem, path=zarr_path) + else: + raise RuntimeError(f"Unsupported zarr version {zarr_major_version}") return s3_path, s3_filesystem @@ -153,6 +166,7 @@ def create_s3_target( url: Optional[str] = None, anon: Optional[str] = False, credential_file: Optional[str] = None, + asynchronous: bool = False, ) -> s3fs.core.S3FileSystem: """Create file system for S3 bucket based on a service endpoint and an optional credential file. If the credential file is not provided, the s3fs.S3FileSystem function checks the environment variables @@ -162,6 +176,7 @@ def create_s3_target( url: Service endpoint for S3 bucket anon: Option for anon argument of S3FileSystem credential_file: File path to credentials + asynchronous: Whether to open the file system in async mode. 
 
     Returns:
         s3_filesystem
     """
     client_kwargs = {"endpoint_url": SERVICE_ENDPOINT if url is None else url}
     if credential_file is not None:
         key, secret = read_s3_credentials(credential_file)
-        s3_filesystem = s3fs.S3FileSystem(key=key, secret=secret, client_kwargs=client_kwargs)
+        s3_filesystem = s3fs.S3FileSystem(
+            key=key, secret=secret, client_kwargs=client_kwargs, asynchronous=asynchronous
+        )
     else:
-        s3_filesystem = s3fs.S3FileSystem(anon=anon, client_kwargs=client_kwargs)
+        s3_filesystem = s3fs.S3FileSystem(anon=anon, client_kwargs=client_kwargs, asynchronous=asynchronous)
     return s3_filesystem
+
+
+def _sync_rclone(local_dir, target):
+    # The rclone alias could also be exposed as a parameter.
+    rclone_alias = "cochlea-lightsheet"
+    print("Sync", local_dir, "to", target)
+    run(["rclone", "--progress", "copyto", local_dir, f"{rclone_alias}:{target}"])
+
+
+def sync_dataset(
+    mobie_root: str,
+    dataset_name: str,
+    bucket_name: Optional[str] = None,
+    url: Optional[str] = None,
+    anon: Optional[bool] = False,
+    credential_file: Optional[str] = None,
+    force_segmentation_update: bool = False,
+) -> None:
+    """Sync a MoBIE dataset to the S3 bucket using rclone.
+
+    Args:
+        mobie_root: The directory with the local MoBIE project.
+        dataset_name: The MoBIE dataset to sync.
+        bucket_name: The name of the dataset's bucket on S3.
+        url: Service endpoint for the S3 bucket.
+        anon: Option for the anon argument of S3FileSystem.
+        credential_file: File path to credentials.
+        force_segmentation_update: Whether to force segmentation updates.
+    """
+    from mobie.metadata import add_remote_project_metadata
+
+    # Make sure that rclone is available.
+    if which("rclone") is None:
+        raise RuntimeError("rclone is required for synchronization. Try loading it via 'module load rclone'.")
+
+    # Make sure the dataset is part of the local version of the project.
+    with open(os.path.join(mobie_root, "project.json")) as f:
+        project_metadata = json.load(f)
+    datasets = project_metadata["datasets"]
+    assert dataset_name in datasets
+
+    # Get the S3 filesystem and bucket name.
+    s3 = create_s3_target(url, anon, credential_file)
+    if bucket_name is None:
+        bucket_name = BUCKET_NAME
+    if url is None:
+        url = SERVICE_ENDPOINT
+
+    # Add the required remote metadata to the project. Suppress warnings about missing local data.
+    with warnings.catch_warnings():
+        warnings.filterwarnings("ignore", category=UserWarning)
+        add_remote_project_metadata(mobie_root, bucket_name, url)
+
+    # Get the metadata from the S3 bucket.
+    project_metadata_path = os.path.join(bucket_name, "project.json")
+    with s3.open(project_metadata_path, "r") as f:
+        project_metadata = json.load(f)
+
+    # Check if the dataset is part of the remote project already.
+    local_ds_root = os.path.join(mobie_root, dataset_name)
+    remote_ds_root = os.path.join(bucket_name, dataset_name)
+    if dataset_name not in project_metadata["datasets"]:
+        print("The dataset is not yet synced. Will copy it over.")
+        _sync_rclone(os.path.join(mobie_root, "project.json"), project_metadata_path)
+        _sync_rclone(local_ds_root, remote_ds_root)
+        return
+
+    # Otherwise, check which sources are new and add them.
+    with open(os.path.join(mobie_root, dataset_name, "dataset.json")) as f:
+        local_dataset_metadata = json.load(f)
+
+    dataset_metadata_path = os.path.join(bucket_name, dataset_name, "dataset.json")
+    with s3.open(dataset_metadata_path, "r") as f:
+        remote_dataset_metadata = json.load(f)
+
+    for source_name, source_data in local_dataset_metadata["sources"].items():
+        source_type, source_data = next(iter(source_data.items()))
+        is_segmentation = source_type == "segmentation"
+        is_spots = source_type == "spots"
+        data_path = source_data["imageData"]["ome.zarr"]["relativePath"]
+        source_not_on_remote = (source_name not in remote_dataset_metadata["sources"])
+        # Only update the image data if the source is not yet on the remote, or if we force updates for segmentations.
+        if source_not_on_remote or (is_segmentation and force_segmentation_update):
+            _sync_rclone(os.path.join(local_ds_root, data_path), os.path.join(remote_ds_root, data_path))
+        # We always sync the tables.
+        if is_segmentation or is_spots:
+            table_path = source_data["tableData"]["tsv"]["relativePath"]
+            _sync_rclone(os.path.join(local_ds_root, table_path), os.path.join(remote_ds_root, table_path))
+
+    # Sync the dataset metadata.
+    _sync_rclone(
+        os.path.join(mobie_root, dataset_name, "dataset.json"), os.path.join(remote_ds_root, "dataset.json")
+    )
diff --git a/reproducibility/label_components/IHC_gerbil.json b/reproducibility/label_components/IHC_gerbil.json
new file mode 100644
index 0000000..cdc2d03
--- /dev/null
+++ b/reproducibility/label_components/IHC_gerbil.json
@@ -0,0 +1,8 @@
+[
+    {
+        "cochlea": "G_EK_000233_L",
+        "image_channel": "VGlut3",
+        "cell_type": "ihc",
+        "unet_version": "v5"
+    }
+]
diff --git a/reproducibility/label_components/IHC_v4c_fig2.json b/reproducibility/label_components/IHC_v4c_fig2.json
index 8d40000..f0bd22d 100644
--- a/reproducibility/label_components/IHC_v4c_fig2.json
+++ b/reproducibility/label_components/IHC_v4c_fig2.json
@@ -22,5 +22,11 @@
         "image_channel": "VGlut3",
         "cell_type": "ihc",
         "unet_version": "v4c"
+    },
+    {
+        "cochlea": "G_EK_000233_L",
+        "image_channel": "VGlut3",
+        "cell_type": "ihc",
+        "unet_version": "v5"
     }
 ]
diff --git a/reproducibility/label_components/repro_label_components.py b/reproducibility/label_components/repro_label_components.py
index d5054a7..730d412 100644
--- a/reproducibility/label_components/repro_label_components.py
+++ b/reproducibility/label_components/repro_label_components.py
@@ -6,6 +6,7 @@ import pandas as pd
 
 from flamingo_tools.s3_utils import get_s3_path
 from flamingo_tools.segmentation.postprocessing import label_components_sgn, label_components_ihc
+from flamingo_tools.segmentation.cochlea_mapping import tonotopic_mapping
 
 
 def repro_label_components(
@@ -14,6 +15,7 @@
     s3_credentials: Optional[str] = None,
     s3_bucket_name: Optional[str] = None,
     s3_service_endpoint: Optional[str] = None,
+    apply_tonotopic_mapping: bool = False,
 ):
     min_size = 1000
     default_threshold_erode = None
@@ -23,7 +25,7 @@
     default_cell_type = "sgn"
     default_component_list = [1]
 
-    with open(ddict, 'r') as myfile:
+    with open(ddict, "r") as myfile:
         data = myfile.read()
     param_dicts = json.loads(data)
 
@@ -39,11 +41,17 @@
         cell_type = dic["cell_type"] if "cell_type" in dic else default_cell_type
         component_list = dic["component_list"] if "component_list" in dic else default_component_list
 
+        # The table name sometimes has to be overwritten.
+ # table_name = "PV_SGN_V2_DA" + # table_name = "CR_SGN_v2" + # table_name = "Ntng1_SGN_v2" + table_name = f"{cell_type.upper()}_{unet_version}" + s3_path = os.path.join(f"{cochlea}", "tables", table_name, "default.tsv") tsv_path, fs = get_s3_path(s3_path, bucket_name=s3_bucket_name, service_endpoint=s3_service_endpoint, credential_file=s3_credentials) - with fs.open(tsv_path, 'r') as f: + with fs.open(tsv_path, "r") as f: table = pd.read_csv(f, sep="\t") if cell_type == "sgn": @@ -66,8 +74,12 @@ def repro_label_components( else: print(f"Custom component(s) have {largest_comp} {cell_type.upper()}s.") + if apply_tonotopic_mapping: + tsv_table = tonotopic_mapping(tsv_table, cell_type=cell_type) + cochlea_str = "-".join(cochlea.split("_")) table_str = "-".join(table_name.split("_")) + os.makedirs(output_dir, exist_ok=True) out_path = os.path.join(output_dir, "_".join([cochlea_str, f"{table_str}.tsv"])) tsv_table.to_csv(out_path, sep="\t", index=False) @@ -77,8 +89,9 @@ def main(): parser = argparse.ArgumentParser( description="Script to label segmentation using a segmentation table and graph connected components.") - parser.add_argument('-i', '--input', type=str, required=True, help="Input JSON dictionary.") - parser.add_argument('-o', "--output", type=str, required=True, help="Output directory.") + parser.add_argument("-i", "--input", type=str, required=True, help="Input JSON dictionary.") + parser.add_argument("-o", "--output", type=str, required=True, help="Output directory.") + parser.add_argument("-t", "--tonotopic_mapping", action="store_true", help="Also compute the tonotopic mapping.") parser.add_argument("--s3_credentials", type=str, default=None, help="Input file containing S3 credentials. " @@ -93,6 +106,7 @@ def main(): repro_label_components( args.input, args.output, args.s3_credentials, args.s3_bucket_name, args.s3_service_endpoint, + apply_tonotopic_mapping=args.tonotopic_mapping, ) diff --git a/reproducibility/templates_processing/REAMDE.md b/reproducibility/templates_processing/REAMDE.md index 2de26aa..84d60d3 100644 --- a/reproducibility/templates_processing/REAMDE.md +++ b/reproducibility/templates_processing/REAMDE.md @@ -12,7 +12,17 @@ For IHC segmentation run: - apply_unet_IHC_template.sbatch - segment_unet_IHC_template.sbatch +After this, run the following to add segmentation to MoBIE, create component labelings and upload to S3: +- templates_transfer/mobie_segmentation_template.sbatch +- templates_transfer/sync_mobie.py +- label_components/repro_label_components.py +- templates_transfer/sync_mobie.py + For ribbon synapse detection without associated IHC segmentation run - detect_synapse_template.sbatch For ribbon synapse detection with associated IHC segmentation run - detect_synapse_marker_template.sbatch + +After this, run the following to add detections to MoBIE and upload to S3: +- templates_transfer/mobie_spots_template.sbatch +- templates_transfer/sync_mobie.py diff --git a/reproducibility/templates_processing/apply_unet_SGN_template.sbatch b/reproducibility/templates_processing/apply_unet_SGN_template.sbatch index fb3d1a9..46d6fbc 100644 --- a/reproducibility/templates_processing/apply_unet_SGN_template.sbatch +++ b/reproducibility/templates_processing/apply_unet_SGN_template.sbatch @@ -4,12 +4,13 @@ #SBATCH -p grete:shared # the partition #SBATCH -G A100:1 # For requesting 1 A100 GPU. 
-#SBATCH -c 1 -#SBATCH --mem 24G +#SBATCH -c 4 +#SBATCH --mem 32G #SBATCH -a 0-9 source ~/.bashrc -micromamba activate micro-sam_gpu +# micromamba activate micro-sam_gpu +micromamba activate sam # Print out some info. echo "Submitting job with sbatch from directory: ${SLURM_SUBMIT_DIR}" @@ -19,7 +20,8 @@ echo "Current node: ${SLURM_NODELIST}" # Run the script -SCRIPT_REPO=/user/schilling40/u15000/flamingo-tools +# SCRIPT_REPO=/user/schilling40/u15000/flamingo-tools +SCRIPT_REPO=/user/pape41/u12086/Work/my_projects/flamingo-tools cd "$SCRIPT_REPO"/flamingo_tools/segmentation/ || exit export SCRIPT_DIR=$SCRIPT_REPO/scripts @@ -37,9 +39,16 @@ export INPUT=/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/"$COC export OUTPUT_FOLDER=/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/"$COCHLEA"/"$SEG_NAME" +# The default v2 model export MODEL=/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/SGN/v2_cochlea_distance_unet_SGN_supervised_2025-05-27 + +# Domain adapted model for MLR99L +# export MODEL=/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/SGN/v2_domain_adaptation_mlr99l/best.pt + export PREDICTION_INSTANCES=10 + export INPUT_KEY="setup$STAIN_CHANNEL/timepoint0/s0" +# export INPUT_KEY="s0" echo "Input directory: ${INPUT}" echo "Output directory: ${OUTPUT_FOLDER}" diff --git a/reproducibility/templates_processing/detect_synapse_marker_template.sbatch b/reproducibility/templates_processing/detect_synapse_marker_template.sbatch index eb97e09..0826fee 100644 --- a/reproducibility/templates_processing/detect_synapse_marker_template.sbatch +++ b/reproducibility/templates_processing/detect_synapse_marker_template.sbatch @@ -7,11 +7,12 @@ #SBATCH -p grete:shared # the partition #SBATCH -G A100:1 # For requesting 1 A100 GPU. #SBATCH -A nim00007 -#SBATCH -c 2 -#SBATCH --mem 36G +#SBATCH -c 8 +#SBATCH --mem 128G source ~/.bashrc -micromamba activate micro-sam_gpu +# micromamba activate micro-sam_gpu +micromamba activate sam # Print out some info. 
echo "Submitting job with sbatch from directory: ${SLURM_SUBMIT_DIR}" @@ -19,10 +20,8 @@ echo "Home directory: ${HOME}" echo "Working directory: $PWD" echo "Current node: ${SLURM_NODELIST}" -# Run the script -#python myprogram.py $SLURM_ARRAY_TASK_ID - -SCRIPT_REPO=/user/schilling40/u15000/flamingo-tools +# SCRIPT_REPO=/user/schilling40/u15000/flamingo-tools +SCRIPT_REPO=/user/pape41/u12086/Work/my_projects/flamingo-tools cd "$SCRIPT_REPO"/flamingo_tools/segmentation/ || exit export SCRIPT_DIR=$SCRIPT_REPO/scripts @@ -46,7 +45,8 @@ export MAX_DISTANCE=8 echo "OUTPUT_FOLDER $OUTPUT_FOLDER" echo "MODEL $MODEL" -python ~/flamingo-tools/scripts/synapse_marker_detection/marker_detection.py \ +SCRIPT="$SCRIPT_DIR"/synapse_marker_detection/marker_detection.py +python $SCRIPT \ --input "$INPUT_PATH" \ --input_key $INPUT_KEY \ --output_folder "$OUTPUT_FOLDER" \ @@ -54,4 +54,3 @@ python ~/flamingo-tools/scripts/synapse_marker_detection/marker_detection.py \ --model $MODEL \ --max_distance $MAX_DISTANCE \ --s3 - diff --git a/reproducibility/templates_processing/detect_synapse_template.sbatch b/reproducibility/templates_processing/detect_synapse_template.sbatch index 4cc292b..888c86f 100644 --- a/reproducibility/templates_processing/detect_synapse_template.sbatch +++ b/reproducibility/templates_processing/detect_synapse_template.sbatch @@ -1,17 +1,18 @@ #!/bin/bash #SBATCH --job-name=synapse-detect -#SBATCH -t 42:00:00 # estimated time, adapt to your needs +#SBATCH -t 12:00:00 # estimated time, adapt to your needs #SBATCH --mail-user=martin.schilling@med.uni-goettingen.de # change this to your mailaddress #SBATCH --mail-type=FAIL # send mail when job begins and ends #SBATCH -p grete:shared # the partition #SBATCH -G A100:1 # For requesting 1 A100 GPU. #SBATCH -A nim00007 -#SBATCH -c 2 -#SBATCH --mem 500G +#SBATCH -c 8 +#SBATCH --mem 128G source ~/.bashrc -micromamba activate micro-sam_gpu +# micromamba activate micro-sam_gpu +micromamba activate sam # Print out some info. echo "Submitting job with sbatch from directory: ${SLURM_SUBMIT_DIR}" @@ -22,7 +23,8 @@ echo "Current node: ${SLURM_NODELIST}" # Run the script #python myprogram.py $SLURM_ARRAY_TASK_ID -SCRIPT_REPO=/user/schilling40/u15000/flamingo-tools +# SCRIPT_REPO=/user/schilling40/u15000/flamingo-tools +SCRIPT_REPO=/user/pape41/u12086/Work/my_projects/flamingo-tools cd "$SCRIPT_REPO"/flamingo_tools/segmentation/ || exit export SCRIPT_DIR=$SCRIPT_REPO/scripts @@ -31,15 +33,13 @@ export SCRIPT_DIR=$SCRIPT_REPO/scripts COCHLEA=$1 # image channel, e.g. CTBP2 or RibA IMAGE_CHANNEL=$2 -# segmentation name, as it appears in MoBIE, e.g. 
synapses_v3 -IHC_SEG=$3 export INPUT_PATH="$COCHLEA"/images/ome-zarr/"$IMAGE_CHANNEL".ome.zarr -export MASK_PATH="$COCHLEA"/images/ome-zarr/"$IHC_SEG".ome.zarr # data on NHR # export INPUT_PATH=/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/"$COCHLEA"/"$DATA" # export INPUT_KEY="setup$STAIN_CHANNEL/timepoint0/s0" +INPUT_KEY="s0" export OUTPUT_FOLDER=/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/"$COCHLEA"/synapses_v3 @@ -52,10 +52,10 @@ export MODEL=/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/train echo "OUTPUT_FOLDER $OUTPUT_FOLDER" echo "MODEL $MODEL" -python ~/flamingo-tools/scripts/synapse_marker_detection/run_prediction.py \ +SCRIPT="$SCRIPT_DIR"/synapse_marker_detection/run_prediction.py +python $SCRIPT \ --input "$INPUT_PATH" \ --input_key "$INPUT_KEY" \ --output_folder "$OUTPUT_FOLDER" \ --model $MODEL \ --s3 - diff --git a/reproducibility/templates_processing/mean_std_SGN_template.sbatch b/reproducibility/templates_processing/mean_std_SGN_template.sbatch index f9e3ee9..bed241a 100644 --- a/reproducibility/templates_processing/mean_std_SGN_template.sbatch +++ b/reproducibility/templates_processing/mean_std_SGN_template.sbatch @@ -8,11 +8,13 @@ #SBATCH --mem 128G source ~/.bashrc -micromamba activate flamingo13 +# micromamba activate flamingo13 +micromamba activate sam # Run the script -SCRIPT_REPO=/user/schilling40/u15000/flamingo-tools +# SCRIPT_REPO=/user/schilling40/u15000/flamingo-tools +SCRIPT_REPO=/user/pape41/u12086/Work/my_projects/flamingo-tools cd "$SCRIPT_REPO"/flamingo_tools/segmentation/ || exit export SCRIPT_DIR=$SCRIPT_REPO/scripts @@ -27,9 +29,12 @@ STAIN_CHANNEL=$3 SEG_NAME=$4 export INPUT=/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/"$COCHLEA"/"$DATA" + export OUTPUT_FOLDER=/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/"$COCHLEA"/"$SEG_NAME" export SEG_CLASS="sgn" + export INPUT_KEY="setup$STAIN_CHANNEL/timepoint0/s0" +# export INPUT_KEY="s0" if ! [[ -f $OUTPUT_FOLDER ]] ; then mkdir -p "$OUTPUT_FOLDER" @@ -46,4 +51,3 @@ cmd_array=( 'import sys,os;' 'output_folder=os.environ["OUTPUT_FOLDER"],seg_class=os.environ["SEG_CLASS"])') cmd="${cmd_array[*]}" python -c "$cmd" - diff --git a/reproducibility/templates_processing/segment_unet_IHC_template.sbatch b/reproducibility/templates_processing/segment_unet_IHC_template.sbatch index 8f37afb..71fa364 100644 --- a/reproducibility/templates_processing/segment_unet_IHC_template.sbatch +++ b/reproducibility/templates_processing/segment_unet_IHC_template.sbatch @@ -8,7 +8,8 @@ #SBATCH --mem 400G source ~/.bashrc -micromamba activate micro-sam_gpu +# micromamba activate micro-sam_gpu +micromamba activate sam # Print out some info. 
echo "Submitting job with sbatch from directory: ${SLURM_SUBMIT_DIR}" @@ -19,7 +20,8 @@ echo "Current node: ${SLURM_NODELIST}" # Run the script #python myprogram.py $SLURM_ARRAY_TASK_ID -SCRIPT_REPO=/user/schilling40/u15000/flamingo-tools +# SCRIPT_REPO=/user/schilling40/u15000/flamingo-tools +SCRIPT_REPO=/user/pape41/u12086/Work/my_projects/flamingo-tools cd "$SCRIPT_REPO"/flamingo_tools/segmentation/ || exit # name of cochlea, as it appears in MoBIE and the NHR diff --git a/reproducibility/templates_processing/segment_unet_SGN_template.sbatch b/reproducibility/templates_processing/segment_unet_SGN_template.sbatch index 8ee9371..331e3d5 100644 --- a/reproducibility/templates_processing/segment_unet_SGN_template.sbatch +++ b/reproducibility/templates_processing/segment_unet_SGN_template.sbatch @@ -8,7 +8,8 @@ #SBATCH --mem 400G source ~/.bashrc -micromamba activate micro-sam_gpu +# micromamba activate micro-sam_gpu +micromamba activate sam # Print out some info. echo "Submitting job with sbatch from directory: ${SLURM_SUBMIT_DIR}" @@ -16,10 +17,8 @@ echo "Home directory: ${HOME}" echo "Working directory: $PWD" echo "Current node: ${SLURM_NODELIST}" -# Run the script -#python myprogram.py $SLURM_ARRAY_TASK_ID - -SCRIPT_REPO=/user/schilling40/u15000/flamingo-tools +# SCRIPT_REPO=/user/schilling40/u15000/flamingo-tools +SCRIPT_REPO=/user/pape41/u12086/Work/my_projects/flamingo-tools cd "$SCRIPT_REPO"/flamingo_tools/segmentation/ || exit # name of cochlea, as it appears in MoBIE and the NHR diff --git a/reproducibility/templates_transfer/mobie_image_template.sbatch b/reproducibility/templates_transfer/mobie_image_template.sbatch index 22334d4..a0bb98b 100644 --- a/reproducibility/templates_transfer/mobie_image_template.sbatch +++ b/reproducibility/templates_transfer/mobie_image_template.sbatch @@ -1,6 +1,6 @@ #!/bin/bash #SBATCH --job-name=mobie_image -#SBATCH -t 01:00:00 # estimated time, adapt to your needs +#SBATCH -t 12:00:00 # estimated time, adapt to your needs #SBATCH --mail-user=martin.schilling@med.uni-goettingen.de # change this to your mailaddress #SBATCH --mail-type=FAIL # send mail when job begins and ends @@ -10,8 +10,8 @@ #SBATCH --mem 180G source ~/.bashrc -source ~/miniconda3/bin/activate -source activate mobie +# source activate sam +micromamba activate sam # Run the script diff --git a/reproducibility/templates_transfer/mobie_segmentation_template.sbatch b/reproducibility/templates_transfer/mobie_segmentation_template.sbatch old mode 100644 new mode 100755 index eb8bbf3..cc17a3f --- a/reproducibility/templates_transfer/mobie_segmentation_template.sbatch +++ b/reproducibility/templates_transfer/mobie_segmentation_template.sbatch @@ -1,6 +1,6 @@ #!/bin/bash #SBATCH --job-name=mobie_segm -#SBATCH -t 01:00:00 # estimated time, adapt to your needs +#SBATCH -t 02:00:00 # estimated time, adapt to your needs #SBATCH --mail-user=martin.schilling@med.uni-goettingen.de # change this to your mailaddress #SBATCH --mail-type=FAIL # send mail when job begins and ends @@ -10,20 +10,21 @@ #SBATCH --mem 180G source ~/.bashrc -source ~/miniconda3/bin/activate -source activate mobie +# source activate mobie +micromamba activate sam # Run the script # name of cochlea, as it appears in MoBIE and the NHR COCHLEA=$1 -# data in n5 format, e.g. GEK11L_PV_GFP_01_fused.n5 +# data in n5 format, e.g. IHC_v5/segmentation.zarr DATA=$2 -# segmentation name, as it appears in MoBIE, e.g. PV or Calb1 +# segmentation name, as it appears in MoBIE, e.g. 
SGN_V2 CHANNEL_NAME=$3 MOBIE_PROJECT="/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet" -INPUT_PATH=/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/"$COCHLEA"/"$DATA" +# /mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/G_EK_000233_L/IHC_v5/segmentation.zarr/ +INPUT_PATH=/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/"$COCHLEA"/"$DATA" SEGMENTATION_KEY="segmentation" @@ -31,6 +32,9 @@ RESOLUTION="[0.38,0.38,0.38]" SCALE_FACTORS="[[2,2,2],[2,2,2],[2,2,2],[2,2,2],[2,2,2],[2,2,2]]" CHUNKS="[64,64,64]" -mobie.add_segmentation --input_path "$SEGMENTATION_PATH" --input_key "$SEGMENTATION_KEY" --root "$MOBIE_PROJECT" \ +echo $INPUT_PATH +echo $MOBIE_PROJECT + +mobie.add_segmentation --input_path "$INPUT_PATH" --input_key "$SEGMENTATION_KEY" --root "$MOBIE_PROJECT" \ --dataset_name "$COCHLEA" --name "$CHANNEL_NAME" --resolution "$RESOLUTION" \ --scale_factors "$SCALE_FACTORS" --chunks "$CHUNKS" diff --git a/reproducibility/templates_transfer/mobie_spots_template.sbatch b/reproducibility/templates_transfer/mobie_spots_template.sbatch index 2cb4fba..d9727e4 100644 --- a/reproducibility/templates_transfer/mobie_spots_template.sbatch +++ b/reproducibility/templates_transfer/mobie_spots_template.sbatch @@ -10,19 +10,22 @@ #SBATCH --mem 16G source ~/.bashrc -source ~/miniconda3/bin/activate -source activate mobie +# source ~/miniconda3/bin/activate +# source activate mobie +micromamba activate sam # Run the script # name of cochlea, as it appears in MoBIE and the NHR COCHLEA=$1 + # segmentation name, as it appears in MoBIE, e.g. synapses_v3 or synapses_v3_ihc_v4 SPOT_NAME=$2 MOBIE_PROJECT="/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet" -TABLE_PATH=/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/"$COCHLEA"/"$SPOT_NAME"/synapse_detection.tsv # synapse_detection_filtered.tsv +# TABLE_PATH=/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/"$COCHLEA"/"$SPOT_NAME"/synapse_detection.tsv +TABLE_PATH=/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/"$COCHLEA"/"$SPOT_NAME"/synapse_detection_filtered.tsv mobie.add_spots --input_table "$TABLE_PATH" --root "$MOBIE_PROJECT" \ --dataset_name "$COCHLEA" --name "$SPOT_NAME" diff --git a/reproducibility/templates_transfer/s3_cochlea_template.sh b/reproducibility/templates_transfer/s3_cochlea_template.sh index 30e4c13..eb22db0 100644 --- a/reproducibility/templates_transfer/s3_cochlea_template.sh +++ b/reproducibility/templates_transfer/s3_cochlea_template.sh @@ -9,3 +9,4 @@ export SERVICE_ENDPOINT="https://s3.fs.gwdg.de" mobie.add_remote_metadata -i $MOBIE_DIR -s $SERVICE_ENDPOINT -b $BUCKET_NAME rclone --progress copyto "$MOBIE_DIR"/"$COCHLEA" cochlea-lightsheet:cochlea-lightsheet/"$COCHLEA" +rclone --progress copyto "$MOBIE_DIR"/project.json cochlea-lightsheet:cochlea-lightsheet/project.json diff --git a/reproducibility/templates_transfer/s3_seg_template.sh b/reproducibility/templates_transfer/s3_seg_template.sh old mode 100644 new mode 100755 index 2b4f100..51c10ba --- a/reproducibility/templates_transfer/s3_seg_template.sh +++ b/reproducibility/templates_transfer/s3_seg_template.sh @@ -12,7 +12,8 @@ export SERVICE_ENDPOINT="https://s3.fs.gwdg.de" mobie.add_remote_metadata -i $MOBIE_DIR -s $SERVICE_ENDPOINT -b $BUCKET_NAME rclone --progress copyto "$MOBIE_DIR"/"$COCHLEA"/dataset.json 
cochlea-lightsheet:cochlea-lightsheet/"$COCHLEA"/dataset.json -rclone --progress copyto "$MOBIE_DIR"/"$COCHLEA"/images/ome-zarr cochlea-lightsheet:cochlea-lightsheet/"$COCHLEA"/images/ome-zarr +rclone --progress copyto "$MOBIE_DIR"/"$COCHLEA"/images/ome-zarr/"$SEG_CHANNEL".ome.zarr cochlea-lightsheet:cochlea-lightsheet/"$COCHLEA"/images/ome-zarr/"$SEG_CHANNEL".ome.zarr +# TODO enable to also sync the whole thing and project.json # take care that segmentation tables containing evaluations (tonotopic mapping, marker labels, etc.) might be overwritten rclone --progress copyto "$MOBIE_DIR"/"$COCHLEA"/tables/"$SEG_CHANNEL" cochlea-lightsheet:cochlea-lightsheet/"$COCHLEA"/tables/"$SEG_CHANNEL" diff --git a/reproducibility/templates_transfer/s3_synapse_template.sh b/reproducibility/templates_transfer/s3_synapse_template.sh old mode 100644 new mode 100755 diff --git a/reproducibility/tonotopic_mapping/repro_tonotopic_mapping.py b/reproducibility/tonotopic_mapping/repro_tonotopic_mapping.py index 8ce6a6c..b6e9e3d 100644 --- a/reproducibility/tonotopic_mapping/repro_tonotopic_mapping.py +++ b/reproducibility/tonotopic_mapping/repro_tonotopic_mapping.py @@ -42,6 +42,7 @@ def repro_tonotopic_mapping( cochlea_str = "-".join(cochlea.split("_")) seg_str = "-".join(seg_channel.split("_")) + os.makedirs(output_dir, exist_ok=True) output_table_path = os.path.join(output_dir, f"{cochlea_str}_{seg_str}.tsv") s3_path = os.path.join(f"{cochlea}", "tables", f"{seg_channel}", "default.tsv") diff --git a/scripts/export_frequency_mapping.py b/scripts/export_frequency_mapping.py index 8177906..9f2bbf2 100644 --- a/scripts/export_frequency_mapping.py +++ b/scripts/export_frequency_mapping.py @@ -36,8 +36,8 @@ def export_frequency_mapping(cochlea, scale, output_folder, source_name, colorma seg_path = os.path.join(cochlea, source["imageData"]["ome.zarr"]["relativePath"]) s3_store, _ = get_s3_path(seg_path, bucket_name=BUCKET_NAME, service_endpoint=SERVICE_ENDPOINT) input_key = f"s{scale}" - with zarr.open(s3_store, mode="r") as f: - seg = f[input_key][:] + f = zarr.open(s3_store, mode="r") + seg = f[input_key][:] mapping = {int(seg_id): freq for seg_id, freq in zip(table.label_id, frequencies)} lut = np.zeros(max_id + 1, dtype="float32") diff --git a/scripts/export_lower_resolution.py b/scripts/export_lower_resolution.py index 67ff4c5..a3946e9 100644 --- a/scripts/export_lower_resolution.py +++ b/scripts/export_lower_resolution.py @@ -153,7 +153,7 @@ def export_lower_resolution(args): internal_path = os.path.join(args.cochlea, "images", "ome-zarr", f"{channel}.ome.zarr") s3_store, fs = get_s3_path(internal_path, bucket_name=BUCKET_NAME, service_endpoint=SERVICE_ENDPOINT) with zarr.open(s3_store, mode="r") as f: - data = f[input_key][:] + data = f[input_key][:].astype("float32") print("Data shape", data.shape) if args.filter_by_components is not None: print(f"Filtering channel {channel} by components {args.filter_by_components}.") diff --git a/scripts/figures/plot_fig2.py b/scripts/figures/plot_fig2.py index a897090..9f1541b 100644 --- a/scripts/figures/plot_fig2.py +++ b/scripts/figures/plot_fig2.py @@ -8,7 +8,7 @@ from matplotlib import colors from skimage.segmentation import find_boundaries -from util import literature_reference_values +from util import literature_reference_values, SYNAPSE_DIR_ROOT png_dpi = 300 @@ -66,7 +66,7 @@ def plot_seg_crop(img_path, seg_path, save_path, xlim1, xlim2, ylim1, ylim2, bou plt.close() -def fig_02b_sgn(save_dir, plot=False): +def fig_02a_sgn(save_dir, plot=False): """Plot crops of 
SGN segmentation of CochleaNet, Cellpose and micro-sam.
     """
     cochlea_dir = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet"
@@ -170,10 +170,10 @@ def fig_02c(save_path, plot=False, all_versions=False):
 
     # Plot
     plt.figure(figsize=(8, 5))
-    main_label_size = 20
+    main_label_size = 22
     sub_label_size = 16
-    main_tick_size = 12
-    legendsize = 16
+    main_tick_size = 16
+    legendsize = 18
 
     plt.scatter(x - offset, precision, label="Precision", marker="o", s=80)
     plt.scatter(x, recall, label="Recall", marker="^", s=80)
@@ -204,7 +204,7 @@ def fig_02c(save_path, plot=False, all_versions=False):
 
 # Load the synapse counts for all IHCs from the relevant tables.
 def _load_ribbon_synapse_counts():
     ihc_version = "ihc_counts_v4c"
-    synapse_dir = f"/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/synapses/{ihc_version}"
+    synapse_dir = os.path.join(SYNAPSE_DIR_ROOT, ihc_version)
     tables = [entry.path for entry in os.scandir(synapse_dir) if "ihc_count_M_LR" in entry.name]
     syn_counts = []
     for tab in tables:
@@ -216,8 +216,8 @@ def _load_ribbon_synapse_counts():
 def fig_02d_01(save_path, plot=False, all_versions=False, plot_average_ribbon_synapses=False):
     """Box plot showing the counts for SGN and IHC per (mouse) cochlea in comparison to literature values.
     """
-    main_tick_size = 16
-    main_label_size = 24
+    main_tick_size = 20
+    main_label_size = 26
 
     rows = 1
     columns = 3 if plot_average_ribbon_synapses else 2
@@ -321,7 +321,7 @@ def fig_02d_02(save_path, filter_zeros=True, plot=False):
     """
     cochleae = ["M_LR_000226_L", "M_LR_000226_R", "M_LR_000227_L", "M_LR_000227_R"]
     ihc_version = "ihc_counts_v4b"
-    synapse_dir = f"/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/synapses/{ihc_version}"
+    synapse_dir = os.path.join(SYNAPSE_DIR_ROOT, ihc_version)
 
     max_dist = 3
     bins = 10
@@ -402,9 +402,12 @@ def main():
 
     os.makedirs(args.figure_dir, exist_ok=True)
 
-    # Panel C: Evaluation of the segmentation results:
-    fig_02b_sgn(save_dir=args.figure_dir, plot=args.plot)
+    # Panels A and B: Qualitative comparison of the segmentation results.
+    fig_02a_sgn(save_dir=args.figure_dir, plot=args.plot)
+    return
     fig_02b_ihc(save_dir=args.figure_dir, plot=args.plot)
+
+    # Panel C: Evaluation of the segmentation results:
     fig_02c(save_path=os.path.join(args.figure_dir, "fig_02c"), plot=args.plot, all_versions=False)
 
     # Panel D: The number of SGNs, IHCs and average number of ribbon synapses per IHC
diff --git a/scripts/figures/plot_fig3.py b/scripts/figures/plot_fig3.py
index 0e8e8c8..51939d7 100644
--- a/scripts/figures/plot_fig3.py
+++ b/scripts/figures/plot_fig3.py
@@ -1,24 +1,39 @@
 import argparse
 import os
 
 import imageio.v3 as imageio
+from glob import glob
+from pathlib import Path
 
+import matplotlib
 import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 from matplotlib import cm, colors
 
-from util import sliding_runlength_sum, frequency_mapping
+from util import sliding_runlength_sum, frequency_mapping, SYNAPSE_DIR_ROOT, to_alias
 
-INPUT_ROOT = "/home/pape/Work/my_projects/flamingo-tools/scripts/M_LR_000227_R/scale3/frequency_mapping"
+INPUT_ROOT = "/home/pape/Work/my_projects/flamingo-tools/scripts/M_LR_000227_R/scale3"
 
-png_dpi = 300
+TYPE_TO_CHANNEL = {
+    "Type-Ia": "CR",
+    "Type-Ib": "Calb1",
+    "Type-Ic": "Lypd1",
+    "Type-II": "Prph",
+}
 
+png_dpi = 300
 
-def fig_03a(save_path):
-    import napari
 
-    path = os.path.join(INPUT_ROOT, "frequencies_IHC_v4c.tif")
-    vol = imageio.imread(path)
+def _plot_colormap(vol, title, plot, save_path):
+    # Increase the default font sizes before creating the figure.
+    matplotlib.rcParams.update({
+        "font.size": 14,  # base font size
+        "axes.titlesize": 18,  # for plt.title / ax.set_title
+        "figure.titlesize": 18,  # for fig.suptitle (if you use it)
+        "xtick.labelsize": 14,
+        "ytick.labelsize": 14,
+        "legend.fontsize": 14,
+    })
 
     # Create the colormap
     fig, ax = plt.subplots(figsize=(6, 1.3))
@@ -31,43 +46,29 @@
     cb = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), cax=ax, orientation="horizontal")
     cb.set_label("Frequency [kHz]")
-    plt.title("Tonotopic Mapping: IHCs")
+    plt.title(title)
     plt.tight_layout()
-    out_path = os.path.join(save_path)
-    plt.savefig(out_path)
-
-    # Show the image in napari for rendering.
-    v = napari.Viewer()
-    v.add_image(vol, colormap="viridis")
-    napari.run()
-
-
-def fig_03b(save_path):
-    import napari
+    if plot:
+        plt.show()
 
-    path = os.path.join(INPUT_ROOT, "frequencies_SGN_v2.tif")
-    vol = imageio.imread(path)
+    plt.savefig(save_path)
+    plt.close()
 
-    # Create the colormap
-    fig, ax = plt.subplots(figsize=(6, 1.3))
-    fig.subplots_adjust(bottom=0.5)
-    freq_min = np.min(np.nonzero(vol))
-    freq_max = vol.max()
-    norm = colors.Normalize(vmin=freq_min, vmax=freq_max, clip=True)
-    cmap = plt.get_cmap("viridis")
-
-    cb = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), cax=ax, orientation="horizontal")
-    cb.set_label("Frequency [kHz]")
-    plt.title("Tonotopic Mapping: SGNs")
-    plt.tight_layout()
-    out_path = os.path.join(save_path)
-    plt.savefig(out_path)
 
+def fig_03a(save_path, plot, plot_napari):
+    path_ihc = os.path.join(INPUT_ROOT, "frequencies_IHC_v4c.tif")
+    path_sgn = os.path.join(INPUT_ROOT, "frequencies_SGN_v2.tif")
+    sgn = imageio.imread(path_sgn)
+    ihc = imageio.imread(path_ihc)
+
+    _plot_colormap(sgn, title="Tonotopic Mapping", plot=plot, save_path=save_path)
 
     # Show the image in napari for rendering.
-    v = napari.Viewer()
-    v.add_image(vol, colormap="viridis")
-    napari.run()
+    if plot_napari:
+        import napari
+        v = napari.Viewer()
+        v.add_image(ihc, colormap="viridis")
+        v.add_image(sgn, colormap="viridis")
+        napari.run()
 
 
 def fig_03c_rl(save_path, plot=False):
@@ -105,18 +106,17 @@ def fig_03c_rl(save_path, plot=False):
 
 def fig_03c_octave(save_path, plot=False):
     ihc_version = "ihc_counts_v4c"
-    synapse_dir = f"/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/synapses/{ihc_version}"
-    tables = [entry.path for entry in os.scandir(synapse_dir) if "ihc_count_M_LR" in entry.name]
+    tables = glob(os.path.join(SYNAPSE_DIR_ROOT, ihc_version, "ihc_count_M_LR*.tsv"))
+    assert len(tables) == 4, len(tables)
 
     result = {"cochlea": [], "octave_band": [], "value": []}
     for tab_path in tables:
-        # TODO map to alias
-        alias = os.path.basename(tab_path)[10:-4].replace("_", "").replace("0", "")
+        cochlea = Path(tab_path).stem[len("ihc_count_"):]
+        alias = to_alias(cochlea)
 
         tab = pd.read_csv(tab_path, sep="\t")
         freq = tab["frequency"].values
         syn_count = tab["synapse_count"].values
 
-        # Compute the running sum of 10 micron.
         octave_binned = frequency_mapping(freq, syn_count, animal="mouse")
 
         result["cochlea"].extend([alias] * len(octave_binned))
@@ -147,6 +147,69 @@ def fig_03c_octave(save_path, plot=False):
         plt.close()
 
 
+def fig_03d_fraction(save_path, plot):
+    result_folder = "../measurements/subtype_analysis"
+    files = glob(os.path.join(result_folder, "*.tsv"))
+
+    # FIXME
+    analysis = {
+        "M_AMD_N62_L": ["CR", "Calb1"],
+        "M_LR_000214_L": ["CR"],
+    }
+
+    results = {"type": [], "fraction": [], "cochlea": []}
+    for ff in files:
+        fname = os.path.basename(ff)
+        cochlea = fname[:-len("_subtype_analysis.tsv")]
+
+        if cochlea not in analysis:
+            continue
+
+        table = pd.read_csv(ff, sep="\t")
+
+        subtype_table = table[[col for col in table.columns if col.startswith("is_")]]
+        assert subtype_table.shape[1] == 2
+        n_sgns = len(subtype_table)
+
+        print(cochlea)
+        for col in subtype_table.columns:
+            vals = table[col].values
+            subtype = col[3:]
+            channel = TYPE_TO_CHANNEL[subtype]
+            if channel not in analysis[cochlea]:
+                continue
+            n_subtype = vals.sum()
+            subtype_fraction = np.round(float(n_subtype) / n_sgns * 100, 2)
+            name = f"{subtype} ({channel})"
+            print(f"{name}:", n_subtype, "/", n_sgns, f"({subtype_fraction} %)")
+
+            results["type"].append(name)
+            results["fraction"].append(subtype_fraction)
+            results["cochlea"].append(cochlea)
+
+        # coexpr = np.logical_and(subtype_table.iloc[:, 0].values, subtype_table.iloc[:, 1].values)
+        # print("Co-expression:", coexpr.sum())
+
+    results = pd.DataFrame(results)
+    fig, ax = plt.subplots()
+    for cochlea, group in results.groupby("cochlea"):
+        ax.scatter(group["type"], group["fraction"], label=cochlea)
+    ax.set_ylabel("Fraction [%]")
+    ax.set_xlabel("Type")
+    ax.legend(title="Cochlea ID")
+
+    plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi)
+    if plot:
+        plt.show()
+    else:
+        plt.close()
+
+
+# TODO
+def fig_03d_octave(save_path, plot):
+    pass
+
+
 def main():
     parser = argparse.ArgumentParser(description="Generate plots for Fig 3 of the cochlea paper.")
     parser.add_argument("--figure_dir", "-f", type=str, help="Output directory for plots.", default="./panels/fig3")
@@ -155,18 +218,17 @@ def main():
 
     os.makedirs(args.figure_dir, exist_ok=True)
 
-    # Panel A: Tonotopic mapping of IHCs (rendering in napari)
-    # fig_03a(save_path=os.path.join(args.figure_dir, "fig_03a.png"))
-
-    # Panel B: Tonotopic mapping of SGNs (rendering in napari)
-    # 
fig_03b(save_path=os.path.join(args.figure_dir, "fig_03b.png"))
+    # Panel A: Tonotopic mapping of SGNs and IHCs (rendering in napari + heatmap)
+    fig_03a(save_path=os.path.join(args.figure_dir, "fig_03a_cmap.png"), plot=args.plot, plot_napari=True)
 
     # Panel C: Spatial distribution of synapses across the cochlea.
     # We have two options: running sum over the runlength or per octave band
     fig_03c_rl(save_path=os.path.join(args.figure_dir, "fig_03c_runlength.png"), plot=args.plot)
     fig_03c_octave(save_path=os.path.join(args.figure_dir, "fig_03c_octave.png"), plot=args.plot)
 
-    # TODO: Panel D: Spatial distribution of SGN sub-types.
+    # Panel D: Spatial distribution of SGN sub-types.
+    fig_03d_fraction(save_path=os.path.join(args.figure_dir, "fig_03d_fraction.png"), plot=args.plot)
+    fig_03d_octave(save_path=os.path.join(args.figure_dir, "fig_03d_octave.png"), plot=args.plot)
 
 
 if __name__ == "__main__":
diff --git a/scripts/figures/plot_fig5.py b/scripts/figures/plot_fig5.py
index b7a61cb..eed6e2c 100644
--- a/scripts/figures/plot_fig5.py
+++ b/scripts/figures/plot_fig5.py
@@ -1,29 +1,34 @@
 import argparse
+import json
 import os
 
 import numpy as np
 import pandas as pd
 import matplotlib.pyplot as plt
 
-from util import literature_reference_values
+from flamingo_tools.s3_utils import BUCKET_NAME, create_s3_target
+
+from util import SYNAPSE_DIR_ROOT
+from plot_fig4 import get_chreef_data
 
 png_dpi = 300
 
 
 def _load_ribbon_synapse_counts():
+    # TODO update the version!
     ihc_version = "ihc_counts_v4b"
-    synapse_dir = f"/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/synapses/{ihc_version}"
-    tables = [entry.path for entry in os.scandir(synapse_dir) if "ihc_count_M_AMD" in entry.name]
-    syn_counts = []
-    for tab in tables:
-        x = pd.read_csv(tab, sep="\t")
-        syn_counts.extend(x["synapse_count"].values.tolist())
+    table_path = os.path.join(SYNAPSE_DIR_ROOT, ihc_version, "ihc_count_M_AMD_OTOF1_L.tsv")
+    x = pd.read_csv(table_path, sep="\t")
+    syn_counts = x.synapse_count.values.tolist()
     return syn_counts
 
 
 def fig_05c(save_path, plot=False):
-    """Bar plot showing the distribution of synapse markers per IHC segmentation average over OTOF cochlea.
+    """Box plot showing the IHC count and the distribution of synapse markers per IHC for the OTOF cochlea.
     """
+    # TODO update the alias.
+    # For MOTOF1L
+    alias = "M10L"
     main_label_size = 20
     main_tick_size = 12
 
@@ -31,43 +36,79 @@
 
     ribbon_synapse_counts = _load_ribbon_synapse_counts()
 
-    fig, ax = plt.subplots(figsize=(8, 4))
+    rows, columns = 1, 2
+    fig, axes = plt.subplots(rows, columns, figsize=(columns*4, rows*4))
+
+    #
+    # Create the plot for IHCs.
+    ihc_values = [len(ribbon_synapse_counts)]
+
+    ylim0 = 600
+    ylim1 = 800
+    y_ticks = [i for i in range(600, 800 + 1, 100)]
+    axes[0].set_ylabel("IHC count", fontsize=main_label_size)
+    axes[0].set_yticks(y_ticks)
+    axes[0].set_yticklabels(y_ticks, rotation=0, fontsize=main_tick_size)
+    axes[0].set_ylim(ylim0, ylim1)
+
+    axes[0].boxplot(ihc_values)
+    axes[0].set_xticklabels([alias], fontsize=main_label_size)
+
+    # Set the reference values for healthy cochleae
+    xmin = 0.5
+    xmax = 1.5
+    ihc_reference_values = [712, 710, 721, 675]  # MLR226L, MLR226R, MLR227L, MLR227R
+
+    ihc_value = np.mean(ihc_reference_values)
+    ihc_std = np.std(ihc_reference_values)
+
+    upper_y = ihc_value + 1.96 * ihc_std
+    lower_y = ihc_value - 1.96 * ihc_std
+
+    axes[0].hlines([lower_y, upper_y], xmin, xmax, colors=["C1" for _ in range(2)])
+    axes[0].text(1, upper_y + 10, "healthy cochleae", color="C1", fontsize=main_tick_size, ha="center")
+    axes[0].fill_between([xmin, xmax], lower_y, upper_y, color="C1", alpha=0.05, interpolate=True)
+
+    #
+    # Create the plot for ribbon synapse distribution.
+    #
 
     ylim0 = -1
     ylim1 = 24
     y_ticks = [i for i in range(0, 25, 5)]
-    ax.set_ylabel("Ribbon Syn. per IHC", fontsize=main_label_size)
-    ax.set_yticks(y_ticks)
-    ax.set_yticklabels(y_ticks, rotation=0, fontsize=main_tick_size)
-    ax.set_ylim(ylim0, ylim1)
+    axes[1].set_ylabel("Ribbon Syn. per IHC", fontsize=main_label_size)
+    axes[1].set_yticks(y_ticks)
+    axes[1].set_yticklabels(y_ticks, rotation=0, fontsize=main_tick_size)
+    axes[1].set_ylim(ylim0, ylim1)
+
+    axes[1].boxplot(ribbon_synapse_counts)
+    axes[1].set_xticklabels([alias], fontsize=main_label_size)
 
-    ax.boxplot(ribbon_synapse_counts)
-    ax.set_xticklabels(["MOTOF1L"], fontsize=main_label_size)
+    axes[1].yaxis.tick_right()
+    axes[1].yaxis.set_ticks_position("right")
+    axes[1].yaxis.set_label_position("right")
 
-    # set range of literature values
+    # Set the reference values for healthy cochleae
     xmin = 0.5
     xmax = 1.5
-    lower_y, upper_y = literature_reference_values("synapse")
-    ax.set_xlim(xmin, xmax)
-    ax.hlines([lower_y, upper_y], xmin, xmax)
-    ax.text(1.25, upper_y + 0.01*ax.get_ylim()[1]-ax.get_ylim()[0], "literature",
-            color="C0", fontsize=htext_size, ha="center")
-    ax.fill_between([xmin, xmax], lower_y, upper_y, color="C0", alpha=0.05, interpolate=True)
+    syn_reference_values = [14.1, 12.7, 13.8, 13.4]  # MLR226L, MLR226R, MLR227L, MLR227R
 
-    ihc_values = [14.1, 12.7, 13.8, 13.4]  # MLR226L, MLR226R, MLR227L, MLR227R
+    syn_value = np.mean(syn_reference_values)
+    syn_std = np.std(syn_reference_values)
 
-    ihc_value = np.mean(ihc_values)
-    ihc_std = np.std(ihc_values)
-
-    upper_y = ihc_value + 1.96 * ihc_std
-    lower_y = ihc_value - 1.96 * ihc_std
+    upper_y = syn_value + 1.96 * syn_std
+    lower_y = syn_value - 1.96 * syn_std
 
     plt.hlines([lower_y, upper_y], xmin, xmax, colors=["C1" for _ in range(2)])
-    plt.text(1.25, upper_y + 0.01*ax.get_ylim()[1]-ax.get_ylim()[0], "healthy cochleae (95% confidence interval)",
-             color="C1", fontsize=htext_size, ha="center")
+    plt.text(
+        1.25, upper_y + 0.01*(axes[1].get_ylim()[1] - axes[1].get_ylim()[0]), "healthy cochleae",
+        color="C1", fontsize=htext_size, ha="center"
+    )
     plt.fill_between([xmin, xmax], lower_y, upper_y, color="C1", alpha=0.05, interpolate=True)
 
+    # Save and plot the figure.
plt.tight_layout() plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi) @@ -77,6 +118,33 @@ def fig_05c(save_path, plot=False): plt.close() +# TODO +def fig_05d(save_path, plot): + if False: + s3 = create_s3_target() + + # Intensity distribution for OTOF + cochlea = "M_AMD_OTOF1_L" + content = s3.open(f"{BUCKET_NAME}/{cochlea}/dataset.json", mode="r", encoding="utf-8") + info = json.loads(content.read()) + sources = info["sources"] + + # Load the seg table and filter the compartments. + source_name = "IHC_v4c" + source = sources[source_name]["segmentation"] + rel_path = source["tableData"]["tsv"]["relativePath"] + table_content = s3.open(os.path.join(BUCKET_NAME, cochlea, rel_path, "default.tsv"), mode="rb") + table = pd.read_csv(table_content, sep="\t") + print(table) + + # TODO would need the new intensity subtracted data here. + # Reference: intensity distributions for ChReef + chreef_data = get_chreef_data() + for cochlea, tab in chreef_data.items(): + plt.hist(tab["median"]) + plt.show() + + def main(): parser = argparse.ArgumentParser(description="Generate plots for Fig 5 of the cochlea paper.") parser.add_argument("--figure_dir", "-f", type=str, help="Output directory for plots.", default="./panels/fig5") @@ -86,7 +154,10 @@ def main(): os.makedirs(args.figure_dir, exist_ok=True) # Panel C: Monitoring of the Syn / IHC loss - fig_05c(save_path=os.path.join(args.figure_dir, "fig_05c"), plot=args.plot) + # fig_05c(save_path=os.path.join(args.figure_dir, "fig_05c"), plot=args.plot) + + # Panel D: Tonotopic mapping of the intensities. + fig_05d(save_path=os.path.join(args.figure_dir, "fig_05d"), plot=args.plot) if __name__ == "__main__": diff --git a/scripts/figures/plot_fig6.py b/scripts/figures/plot_fig6.py index 2c0a2c7..0553739 100644 --- a/scripts/figures/plot_fig6.py +++ b/scripts/figures/plot_fig6.py @@ -8,7 +8,7 @@ png_dpi = 300 -def fig_06a(save_path, plot=False): +def fig_06b(save_path, plot=False): """Box plot showing the counts for SGN and IHC per gerbil cochlea in comparison to literature values. """ main_tick_size = 12 @@ -76,13 +76,29 @@ def fig_06a(save_path, plot=False): plt.close() +def fig_06d(save_path, plot=False): + """Plot the synapse distribution measured with different markers. + + The underlying measurements were done with 'scripts/measurements/synapse_colocalization.py' + + Here are the other relevant numbers for the analysis. 
+ Number of IHCs: 486 + Number of matched synapses: 3119 + Number and percentage of matched synapses for markers: + CTBP2: 3119 / 3371 (92.52447345001484% matched) + RibA : 3119 / 6701 (46.54529174750037% matched) + """ + # TODO Plot this + + def main(): - parser = argparse.ArgumentParser(description="Generate plots for Fig 2 of the cochlea paper.") + parser = argparse.ArgumentParser(description="Generate plots for Fig 6 of the cochlea paper.") parser.add_argument("figure_dir", type=str, help="Output directory for plots.", default="./panels") args = parser.parse_args() plot = False - fig_06a(save_path=os.path.join(args.figure_dir, "fig_06a"), plot=plot) + fig_06b(save_path=os.path.join(args.figure_dir, "fig_06b"), plot=plot) + fig_06d(save_path=os.path.join(args.figure_dir, "fig_06d"), plot=plot) if __name__ == "__main__": diff --git a/scripts/figures/segmentation_comparison.py b/scripts/figures/segmentation_comparison.py new file mode 100644 index 0000000..1b0932e --- /dev/null +++ b/scripts/figures/segmentation_comparison.py @@ -0,0 +1,126 @@ +import os +from glob import glob +from pathlib import Path +import json + +import imageio.v3 as imageio +import napari +import numpy as np +from skimage.segmentation import find_boundaries + +FOR_COMPARISON = ["distance_unet", "micro-sam", "cellpose3"] + + +def _eval_seg(seg, eval_path): + with open(eval_path, "r") as f: + eval_res = json.load(f) + + correct, wrong = eval_res["tp_objects"], eval_res["fp"] + all_ids = correct + wrong + seg[~np.isin(seg, all_ids)] = 0 + + eva_mask = np.zeros_like(seg) + + eva_mask[np.isin(seg, correct)] = 1 + eva_mask[np.isin(seg, wrong)] = 2 + + bd = find_boundaries(seg) + return bd, eva_mask + + +def sgn_comparison(): + scale = (0.38,) * 2 + z = 10 + + cochlea_dir = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet" + val_sgn_dir = f"{cochlea_dir}/predictions/val_sgn" + image_dir = f"{cochlea_dir}/AnnotatedImageCrops/F1ValidationSGNs/for_consensus_annotation" + + image_paths = sorted(glob(os.path.join(image_dir, "*.tif"))) + + for path in image_paths: + image = imageio.imread(path)[z] + + seg_fname = Path(path).stem + "_seg.tif" + eval_fname = Path(path).stem + "_dic.json" + + segmentations, boundaries, eval_im = {}, {}, {} + for seg_name in FOR_COMPARISON: + seg_path = os.path.join(val_sgn_dir, seg_name, seg_fname) + eval_path = os.path.join(val_sgn_dir, seg_name, eval_fname) + assert os.path.exists(seg_path), seg_path + + seg = imageio.imread(seg_path)[z] + + bd, eva = _eval_seg(seg, eval_path) + segmentations[seg_name] = seg + boundaries[seg_name] = bd + eval_im[seg_name] = eva + + v = napari.Viewer() + v.add_image(image, scale=scale) + for seg_name, bd in boundaries.items(): + v.add_labels(bd, name=seg_name, colormap={1: "cyan"}, scale=scale) + v.add_labels(eval_im[seg_name], name=f"{seg_name}_eval", colormap={1: "green", 2: "red"}, scale=scale) + + v.scale_bar.visible = True + v.scale_bar.unit = "μm" + v.scale_bar.font_size = 16 + v.title = Path(path).stem + + napari.run() + + +def ihc_comparison(): + scale = (0.38,) * 2 + z = 10 + + cochlea_dir = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet" + val_sgn_dir = f"{cochlea_dir}/predictions/val_ihc" + image_dir = f"{cochlea_dir}/AnnotatedImageCrops/F1ValidationIHCs" + + image_paths = sorted(glob(os.path.join(image_dir, "*.tif"))) + scale = (0.38,) * 2 + + for path in image_paths: + image = imageio.imread(path)[z] + + seg_fname = Path(path).stem + "_seg.tif" + eval_fname = Path(path).stem + "_dic.json" + + segmentations, boundaries, 
eval_im = {}, {}, {}
+    for seg_name in FOR_COMPARISON:
+        # FIXME distance_unet_v4b is missing the eval files
+        seg_name_ = "distance_unet_v3" if seg_name == "distance_unet" else seg_name
+        seg_path = os.path.join(val_sgn_dir, seg_name_, seg_fname)
+        eval_path = os.path.join(val_sgn_dir, seg_name_, eval_fname)
+        assert os.path.exists(seg_path), seg_path
+
+        seg = imageio.imread(seg_path)[z]
+
+        bd, eva = _eval_seg(seg, eval_path)
+        segmentations[seg_name] = seg
+        boundaries[seg_name] = bd
+        eval_im[seg_name] = eva
+
+    v = napari.Viewer()
+    v.add_image(image, scale=scale)
+    for seg_name, bd in boundaries.items():
+        v.add_labels(bd, name=seg_name, colormap={1: "cyan"}, scale=scale)
+        v.add_labels(eval_im[seg_name], name=f"{seg_name}_eval", colormap={1: "green", 2: "red"}, scale=scale)
+
+    v.scale_bar.visible = True
+    v.scale_bar.unit = "μm"
+    v.scale_bar.font_size = 16
+    v.title = Path(path).stem
+
+    napari.run()
+
+
+def main():
+    # sgn_comparison()
+    ihc_comparison()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/figures/supp_fig2.py b/scripts/figures/supp_fig2.py
new file mode 100644
index 0000000..3f21d30
--- /dev/null
+++ b/scripts/figures/supp_fig2.py
@@ -0,0 +1,70 @@
+import json
+import os
+
+from glob import glob
+from pathlib import Path
+
+import numpy as np
+
+
+# FIXME something is off with cellpose-sam runtimes
+def runtimes_sgn():
+    for_comparison = ["distance_unet", "micro-sam", "cellpose3", "cellpose-sam", "stardist"]
+
+    cochlea_dir = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet"
+    val_sgn_dir = f"{cochlea_dir}/predictions/val_sgn"
+    image_dir = f"{cochlea_dir}/AnnotatedImageCrops/F1ValidationSGNs/for_consensus_annotation"
+
+    image_paths = sorted(glob(os.path.join(image_dir, "*.tif")))
+
+    runtimes = {name: [] for name in for_comparison}
+
+    for path in image_paths:
+        eval_fname = Path(path).stem + "_dic.json"
+        for seg_name in for_comparison:
+            eval_path = os.path.join(val_sgn_dir, seg_name, eval_fname)
+            with open(eval_path, "r") as f:
+                result = json.load(f)
+            rt = result["time"]
+            runtimes[seg_name].append(rt)
+
+    for name, rts in runtimes.items():
+        print(name, ":", np.mean(rts), "+-", np.std(rts))
+
+
+def runtimes_ihc():
+    for_comparison = ["distance_unet_v3", "micro-sam", "cellpose3", "cellpose-sam"]
+
+    cochlea_dir = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet"
+    val_ihc_dir = f"{cochlea_dir}/predictions/val_ihc"
+    image_dir = f"{cochlea_dir}/AnnotatedImageCrops/F1ValidationIHCs"
+
+    image_paths = sorted(glob(os.path.join(image_dir, "*.tif")))
+
+    runtimes = {name: [] for name in for_comparison}
+
+    for path in image_paths:
+        eval_fname = Path(path).stem + "_dic.json"
+        for seg_name in for_comparison:
+            eval_path = os.path.join(val_ihc_dir, seg_name, eval_fname)
+            if not os.path.exists(eval_path):
+                continue
+            with open(eval_path, "r") as f:
+                result = json.load(f)
+            rt = result["time"]
+            runtimes[seg_name].append(rt)
+
+    for name, rts in runtimes.items():
+        print(name, ":", np.mean(rts), "+-", np.std(rts))
+
+
+def main():
+    print("SGNs:")
+    runtimes_sgn()
+    print()
+    print("IHCs:")
+    runtimes_ihc()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/figures/util.py b/scripts/figures/util.py
index 8cf6501..489fa95 100644
--- a/scripts/figures/util.py
+++ b/scripts/figures/util.py
@@ -1,6 +1,10 @@
 import pandas as pd
 import numpy as np
 
+# Directory with synapse measurement tables
+SYNAPSE_DIR_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/synapses"
+# SYNAPSE_DIR_ROOT = "./synapses"
+
 
 # Define the animal specific octave bands.
 def _get_mapping(animal):
@@ -60,6 +64,7 @@ def sliding_runlength_sum(run_length, values, width):
     return x, window_sum
 
 
+# For mouse
 def literature_reference_values(structure):
     if structure == "SGN":
         lower_bound, upper_bound = 9141, 11736
@@ -72,6 +77,7 @@ def literature_reference_values(structure):
     return lower_bound, upper_bound
 
 
+# For gerbil
 def literature_reference_values_gerbil(structure):
     if structure == "SGN":
         lower_bound, upper_bound = 24700, 28450
@@ -81,4 +87,15 @@ def literature_reference_values_gerbil(structure):
         lower_bound, upper_bound = 9.1, 20.7
     else:
         raise ValueError
-    return lower_bound, upper_bound
\ No newline at end of file
+    return lower_bound, upper_bound
+
+
+def to_alias(cochlea_name):
+    name_short = cochlea_name.replace("_", "").replace("0", "")
+    name_to_alias = {
+        "MLR226L": "M01L",
+        "MLR226R": "M01R",
+        "MLR227L": "M02L",
+        "MLR227R": "M02R",
+    }
+    return name_to_alias[name_short]
diff --git a/scripts/la-vision/analyze_la_vision.py b/scripts/la-vision/analyze_la_vision.py
new file mode 100644
index 0000000..c4c6d1e
--- /dev/null
+++ b/scripts/la-vision/analyze_la_vision.py
@@ -0,0 +1,75 @@
+import json
+import os
+
+import numpy as np
+import pandas as pd
+from flamingo_tools.s3_utils import create_s3_target, BUCKET_NAME, get_s3_path
+
+
+# Note: downsampling with anisotropic scale in the beginning would make sense for better visualization.
+def analyze_sgn(visualize=False):
+    s3 = create_s3_target()
+    datasets = ["LaVision-M04", "LaVision-Mar05"]
+
+    # Use this to select the components for analysis.
+    sgn_components = {
+        "LaVision-M04": [1],
+        "LaVision-Mar05": [1],
+    }
+    seg_name = "SGN_LOWRES-v2"
+
+    for cochlea in datasets:
+        content = s3.open(f"{BUCKET_NAME}/{cochlea}/dataset.json", mode="r", encoding="utf-8")
+        info = json.loads(content.read())
+        sources = info["sources"]
+
+        # Load the segmentation table.
+        seg_source = sources[seg_name]
+        table_folder = os.path.join(
+            BUCKET_NAME, cochlea, seg_source["segmentation"]["tableData"]["tsv"]["relativePath"]
+        )
+        table_content = s3.open(os.path.join(table_folder, "default.tsv"), mode="rb")
+        table = pd.read_csv(table_content, sep="\t")
+
+        if visualize:
+            import napari
+            import zarr
+            from nifty.tools import takeDict
+
+            key = "s2"
+            img_s3 = f"{cochlea}/images/ome-zarr/PV.ome.zarr"
+            seg_s3 = os.path.join(cochlea, seg_source["segmentation"]["imageData"]["ome.zarr"]["relativePath"])
+            img_path, _ = get_s3_path(img_s3)
+            seg_path, _ = get_s3_path(seg_s3)
+
+            print("Loading image data")
+            f = zarr.open(seg_path, mode="r")
+            seg = f[key][:]
+
+            seg_ids = np.unique(seg)
+            component_dict = {int(label_id): int(component_id)
+                              for label_id, component_id in zip(table.label_id, table.component_labels)}
+            missing_ids = np.setdiff1d(seg_ids, table.label_id.values)
+            component_dict.update({int(miss): 0 for miss in missing_ids})
+            components = takeDict(component_dict, seg)
+
+            f = zarr.open(img_path, mode="r")
+            data = f[key][:]
+
+            v = napari.Viewer()
+            v.add_image(data)
+            v.add_labels(seg)
+            v.add_labels(components)
+            napari.run()
+
+        table = table[table.component_labels.isin(sgn_components[cochlea])]
+        n_sgns = len(table)
+        print(cochlea, ":", n_sgns)
+
+
+def main():
+    analyze_sgn(visualize=True)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/la-vision/import_marmoset_lsm.py b/scripts/la-vision/import_marmoset_lsm.py
new file mode 100644
index 0000000..e93dece
--- /dev/null
+++ b/scripts/la-vision/import_marmoset_lsm.py
@@ -0,0 +1,36 @@
+import os
+
+import imageio.v3 as imageio
+from mobie import add_image
+
+INPUT_ROOT = "/mnt/ceph-hdd/cold/nim00007/cochlea-lightsheet/keppeler-et-al/marmoset"
+MOBIE_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet"
+DS_NAME = "LaVision-Mar05"
+RESOLUTION = (3.0, 1.887779, 1.887779)
+
+
+# Marmoset_cochlea_05_LSFM_ch1_raw.tif
+def add_marmoset_05():
+    channel_names = ("PV", "7-AAD", "MYO")
+
+    scale_factors = 4 * [[2, 2, 2]]
+    chunks = (96, 96, 96)
+
+    for channel_id, channel_name in enumerate(channel_names, 1):
+        input_path = os.path.join(INPUT_ROOT, f"Marmoset_cochlea_05_LSFM_ch{channel_id}_raw.tif")
+        print("Load image data ...")
+        input_data = imageio.imread(input_path)
+        print(input_data.shape)
+        add_image(
+            input_path=input_data, input_key=None, root=MOBIE_ROOT,
+            dataset_name=DS_NAME, image_name=channel_name, resolution=RESOLUTION,
+            scale_factors=scale_factors, chunks=chunks, unit="micrometer", use_memmap=False,
+        )
+
+
+def main():
+    add_marmoset_05()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/la-vision/import_mouse_lsm.py b/scripts/la-vision/import_mouse_lsm.py
new file mode 100644
index 0000000..ae98b44
--- /dev/null
+++ b/scripts/la-vision/import_mouse_lsm.py
@@ -0,0 +1,43 @@
+import os
+
+import imageio.v3 as imageio
+from mobie import add_image
+from mobie.metadata import read_dataset_metadata
+
+INPUT_ROOT = "/mnt/ceph-hdd/cold/nim00007/cochlea-lightsheet/keppeler-et-al/mouse"
+MOBIE_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet"
+DS_NAME = "LaVision-M04"
+RESOLUTION = (3.0, 1.887779, 1.887779)
+
+
+# Mouse_cochlea_04_LSFM_ch1_raw.tif Mouse_cochlea_04_LSFM_ch2_raw.tif Mouse_cochlea_04_LSFM_ch3_raw.tif
+def add_mouse_lsm():
+    channel_names = ("PV", "7-AAD", "MYO")
+
+    scale_factors = 4 * [[2, 2, 2]]
+    chunks = (96, 96, 96)
+
+    for channel_id, channel_name in enumerate(channel_names, 1):
+        mobie_ds_folder = os.path.join(MOBIE_ROOT, DS_NAME)
+        ds_metadata = read_dataset_metadata(mobie_ds_folder)
+        if channel_name in ds_metadata["sources"]:
+            print(channel_name, "is already in MoBIE")
+            continue
+
+        input_path = os.path.join(INPUT_ROOT, f"Mouse_cochlea_04_LSFM_ch{channel_id}_raw.tif")
+        print("Load image data ...")
+        input_data = imageio.imread(input_path)
+        print(input_data.shape)
+        add_image(
+            input_path=input_data, input_key=None, root=MOBIE_ROOT,
+            dataset_name=DS_NAME, image_name=channel_name, resolution=RESOLUTION,
+            scale_factors=scale_factors, chunks=chunks, unit="micrometer", use_memmap=False,
+        )
+
+
+def main():
+    add_mouse_lsm()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/la-vision/predict_blocks.py b/scripts/la-vision/predict_blocks.py
new file mode 100644
index 0000000..d4316be
--- /dev/null
+++ b/scripts/la-vision/predict_blocks.py
@@ -0,0 +1,98 @@
+import argparse
+import os
+from glob import glob
+
+import imageio.v3 as imageio
+import numpy as np
+
+from skimage.segmentation import watershed
+from skimage.measure import label
+from torch_em.util import load_model
+from torch_em.util.prediction import predict_with_halo
+
+
+def _get_files(sgn=True):
+    input_root = "/mnt/vast-nhr/home/pape41/u12086/Work/my_projects/flamingo-tools/scripts/more-annotations/LA_VISION_M04"  # noqa
+    input_root2 = "/mnt/vast-nhr/home/pape41/u12086/Work/my_projects/flamingo-tools/scripts/more-annotations/LA_VISION_M04_2"  # noqa
+    input_root3 = "/mnt/vast-nhr/home/pape41/u12086/Work/my_projects/flamingo-tools/scripts/more-annotations/LA_VISION_Mar05"  # noqa
+
+    input_root_ihc = "/mnt/vast-nhr/home/pape41/u12086/Work/my_projects/flamingo-tools/scripts/more-annotations/LA_VISION_Mar05-ihc"  # noqa
+
+    if sgn:
+        input_files = glob(os.path.join(input_root, "*.tif")) +\
+            glob(os.path.join(input_root2, "*.tif")) +\
+            glob(os.path.join(input_root3, "*.tif"))
+    else:
+        input_files = glob(os.path.join(input_root_ihc, "*.tif"))
+
+    return input_files
+
+
+def predict_blocks(model_path, name, sgn=True):
+    output_folder = os.path.join("./predictions", name)
+    os.makedirs(output_folder, exist_ok=True)
+
+    input_blocks = _get_files(sgn)
+
+    model = None
+    for path in input_blocks:
+        out_path = os.path.join(output_folder, os.path.basename(path))
+        if os.path.exists(out_path):
+            continue
+        if model is None:
+            model = load_model(model_path)
+        data = imageio.imread(path)
+        pred = predict_with_halo(data, model, gpu_ids=[0], block_shape=[64, 128, 128], halo=[8, 32, 32])
+        imageio.imwrite(out_path, pred, compression="zlib")
+
+
+def _segment_impl(pred, dist_threshold=0.5):
+    fg, center_dist, boundary_dist = pred
+    mask = fg > 0.5
+
+    seeds = label(np.logical_and(center_dist < dist_threshold, boundary_dist < dist_threshold))
+    seg = watershed(boundary_dist, mask=mask, markers=seeds)
+
+    return seg
+
+
+def check_segmentation(name, sgn):
+    import napari
+
+    input_files = _get_files(sgn)
+
+    output_folder = os.path.join("./predictions", name)
+
+    for path in input_files:
+        image = imageio.imread(path)
+        pred_path = os.path.join(output_folder, os.path.basename(path))
+        pred = imageio.imread(pred_path)
+        if sgn:
+            seg = _segment_impl(pred)
+        else:
+            seg = label(pred[0] > 0.5)
+        v = napari.Viewer()
+        v.add_image(image)
+        v.add_image(pred)
+        v.add_labels(seg)
+        v.title = os.path.basename(path)
+        napari.run()
+
+
+# Model path for original training on low-res SGNs:
+# /mnt/vast-nhr/home/pape41/u12086/Work/my_projects/flamingo-tools/scripts/training/checkpoints/cochlea_distance_unet_low-res-sgn  # noqa
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--name", "-n", required=True)
+    parser.add_argument("--model_path", "-m")
+    parser.add_argument("--check", action="store_true")
+    parser.add_argument("--ihc", action="store_true")
+    args = parser.parse_args()
+
+    predict_blocks(args.model_path, args.name, sgn=not args.ihc)
+    if args.check:
+        check_segmentation(args.name, sgn=not args.ihc)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/la-vision/segment_sgns.py b/scripts/la-vision/segment_sgns.py
new file mode 100644
index 0000000..f0ec7d1
--- /dev/null
+++ b/scripts/la-vision/segment_sgns.py
@@ -0,0 +1,101 @@
+import os
+from subprocess import run
+
+import pandas as pd
+from flamingo_tools.segmentation import run_unet_prediction
+from flamingo_tools.segmentation.postprocessing import label_components_sgn
+from mobie import add_segmentation
+from mobie.metadata import add_remote_project_metadata
+
+MODEL_PATH = "/mnt/vast-nhr/home/pape41/u12086/Work/my_projects/flamingo-tools/scripts/training/checkpoints/cochlea_distance_unet_low-res-sgn-v2"  # noqa
+RESOLUTION = (3.0, 1.887779, 1.887779)
+SEG_NAME = "SGN_LOWRES-v2"
+
+
+def segment_sgns(input_path, input_key, output_folder):
+    run_unet_prediction(
+        input_path, input_key, output_folder,
+        model_path=MODEL_PATH, min_size=50,
+        block_shape=(64, 128, 128), halo=(8, 32, 32),
+        center_distance_threshold=0.5, boundary_distance_threshold=0.5,
+    )
+
+
+def add_to_mobie(output_folder, mobie_dir, dataset_name):
+    segmentation_path = os.path.join(output_folder, "segmentation.zarr")
+    segmentation_key = "segmentation"
+
+    scale_factors = 4 * [[2, 2, 2]]
+    chunks = (96, 96, 96)
+
+    add_segmentation(
+        segmentation_path, segmentation_key,
+        mobie_dir, dataset_name=dataset_name,
+        segmentation_name=SEG_NAME,
+        resolution=RESOLUTION,
+        scale_factors=scale_factors, chunks=chunks,
+        unit="micrometer",
+    )
+
+
+def compute_components(mobie_dir, dataset_name):
+    table_path = os.path.join(mobie_dir, dataset_name, "tables", SEG_NAME, "default.tsv")
+    table = pd.read_csv(table_path, sep="\t")
+    # This may need to be further adapted
+    table = label_components_sgn(table, min_size=100,
+                                 threshold_erode=None,
+                                 min_component_length=50,
+                                 max_edge_distance=30,
+                                 iterations_erode=None)
+    table.to_csv(table_path, sep="\t", index=False)
+
+
+def upload_to_s3(mobie_dir, dataset_name):
+    service_endpoint = "https://s3.fs.gwdg.de"
+    bucket_name = "cochlea-lightsheet"
+
+    add_remote_project_metadata(mobie_dir, bucket_name, service_endpoint)
+
+    # run(["module", "load", "rclone"])
+    run(["rclone", "--progress", "copyto",
+         f"{mobie_dir}/{dataset_name}/dataset.json",
+         f"cochlea-lightsheet:cochlea-lightsheet/{dataset_name}/dataset.json"])
+    run(["rclone", "--progress", "copyto",
+         f"{mobie_dir}/{dataset_name}/images/ome-zarr",
+         f"cochlea-lightsheet:cochlea-lightsheet/{dataset_name}/images/ome-zarr"])
+    run(["rclone", "--progress", "copyto",
+         f"{mobie_dir}/{dataset_name}/tables/{SEG_NAME}",
+         f"cochlea-lightsheet:cochlea-lightsheet/{dataset_name}/tables/{SEG_NAME}"])
+
+
+def segmentation_workflow(mobie_dir, output_folder, dataset_name):
+    input_path = os.path.join(mobie_dir, dataset_name, "images/ome-zarr/PV.ome.zarr")
+    input_key = "s0"
+
+    segment_sgns(input_path, input_key, output_folder)
+    add_to_mobie(output_folder, mobie_dir, dataset_name)
+    compute_components(mobie_dir, dataset_name)
+    upload_to_s3(mobie_dir, dataset_name)
+
+
+def segment_M04():
+    mobie_dir = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet"  # noqa
+    output_folder = "./segmentation/M04"
+    dataset_name = "LaVision-M04"
+    segmentation_workflow(mobie_dir, output_folder, dataset_name)
+
+
+def segment_Mar05():
+    mobie_dir = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet"  # noqa
+    output_folder = "./segmentation/Mar05"
+    dataset_name = "LaVision-Mar05"
+    segmentation_workflow(mobie_dir, output_folder, dataset_name)
+
+
+def main():
+    segment_M04()
+    segment_Mar05()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/la-vision/upload_marmoset_lsm.sh b/scripts/la-vision/upload_marmoset_lsm.sh
new file mode 100755
index 0000000..2d07096
--- /dev/null
+++ b/scripts/la-vision/upload_marmoset_lsm.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+MOBIE_DIR=/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet
+COCHLEA=LaVision-Mar05
+
+export BUCKET_NAME="cochlea-lightsheet"
+export SERVICE_ENDPOINT="https://s3.fs.gwdg.de"
+mobie.add_remote_metadata -i $MOBIE_DIR -s $SERVICE_ENDPOINT -b $BUCKET_NAME
+
+rclone --progress copyto "$MOBIE_DIR"/"$COCHLEA" cochlea-lightsheet:cochlea-lightsheet/"$COCHLEA"
+rclone --progress copyto "$MOBIE_DIR"/project.json cochlea-lightsheet:cochlea-lightsheet/project.json
diff --git a/scripts/la-vision/upload_mouse_lsm.sh b/scripts/la-vision/upload_mouse_lsm.sh
new file mode 100755
index 0000000..eb696f5
--- /dev/null
+++ b/scripts/la-vision/upload_mouse_lsm.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+MOBIE_DIR=/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet
+COCHLEA=LaVision-M04
+
+export BUCKET_NAME="cochlea-lightsheet"
+export SERVICE_ENDPOINT="https://s3.fs.gwdg.de"
+mobie.add_remote_metadata -i $MOBIE_DIR -s $SERVICE_ENDPOINT -b $BUCKET_NAME
+
+rclone --progress copyto "$MOBIE_DIR"/"$COCHLEA" cochlea-lightsheet:cochlea-lightsheet/"$COCHLEA"
+rclone --progress copyto "$MOBIE_DIR"/project.json cochlea-lightsheet:cochlea-lightsheet/project.json
diff --git a/scripts/measurements/measure_gerbil.py b/scripts/measurements/measure_gerbil.py
new file mode 100644
index 0000000..b10dde1
--- /dev/null
+++ b/scripts/measurements/measure_gerbil.py
@@ -0,0 +1,100 @@
+import json
+import os
+
+import numpy as np
+import pandas as pd
+from flamingo_tools.s3_utils import create_s3_target, BUCKET_NAME
+
+COCHLEAE = ["G_EK_000233_L"]
+SGN_COMPONENTS = {}
+IHC_COMPONENTS = {"G_EK_000233_L": [1, 2, 3, 4, 5, 8]}
+
+
+def open_json(fs, path):
+    s3_path = os.path.join(BUCKET_NAME, path)
+    with fs.open(s3_path, "r") as f:
+        content = json.load(f)
+    return content
+
+
+def open_tsv(fs, path):
+    s3_path = os.path.join(BUCKET_NAME, path)
+    with fs.open(s3_path, "r") as f:
+        table = pd.read_csv(f, sep="\t")
+    return table
+
+
+def measure_sgns(fs):
+    print("SGNs:")
+    seg_name = "SGN_v2"
+    for dataset in COCHLEAE:
+        print("Cochlea:", dataset)
+        dataset_info = open_json(fs, os.path.join(dataset, "dataset.json"))
+        sources = dataset_info["sources"]
+        assert seg_name in sources
+
+        source_info = sources[seg_name]["segmentation"]
+        table_path = source_info["tableData"]["tsv"]["relativePath"]
+        table = open_tsv(fs, os.path.join(dataset, table_path, "default.tsv"))
+
+        component_labels = table.component_labels.values
+        component_ids = SGN_COMPONENTS.get(dataset, [1])
+        n_sgns = np.isin(component_labels, component_ids).sum()
+        print("N-SGNs:", n_sgns)
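+        # Equivalent pandas one-liner (sketch):
+        # n_sgns = table[table.component_labels.isin(component_ids)].shape[0]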
+
+
+def measure_ihcs(fs):
+    print("IHCs:")
+    seg_name = "IHC_v5"
+    for dataset in COCHLEAE:
+        print("Cochlea:", dataset)
+        dataset_info = open_json(fs, os.path.join(dataset, "dataset.json"))
+        sources = dataset_info["sources"]
+        assert seg_name in sources
+
+        source_info = sources[seg_name]["segmentation"]
+        table_path = source_info["tableData"]["tsv"]["relativePath"]
+        table = open_tsv(fs, os.path.join(dataset, table_path, "default.tsv"))
+
+        component_labels = table.component_labels.values
+        component_ids = IHC_COMPONENTS.get(dataset, [1])
+        n_ihcs = np.isin(component_labels, component_ids).sum()
+        print("N-IHCs:", n_ihcs)
+
+
+def measure_synapses(fs):
+    print("Synapses:")
+    spot_name = "synapses_v3_IHC_v5"
+    seg_name = "IHC_v5"
+    for dataset in COCHLEAE:
+        print("Cochlea:", dataset)
+        dataset_info = open_json(fs, os.path.join(dataset, "dataset.json"))
+        sources = dataset_info["sources"]
+        assert spot_name in sources
+
+        source_info = sources[spot_name]["spots"]
+        table_path = source_info["tableData"]["tsv"]["relativePath"]
+        table = open_tsv(fs, os.path.join(dataset, table_path, "default.tsv"))
+
+        source_info = sources[seg_name]["segmentation"]
+        table_path = source_info["tableData"]["tsv"]["relativePath"]
+        ihc_table = open_tsv(fs, os.path.join(dataset, table_path, "default.tsv"))
+
+        ihc_components = IHC_COMPONENTS.get(dataset, [1])
+        valid_ihcs = ihc_table.label_id[ihc_table.component_labels.isin(ihc_components)]
+        table = table[table.matched_ihc.isin(valid_ihcs)]
+
+        _, syn_count = np.unique(table.matched_ihc.values, return_counts=True)
+        print("Avg Syn. per IHC:")
+        print(np.mean(syn_count), "+-", np.std(syn_count))
+
+
+def main():
+    fs = create_s3_target()
+    measure_sgns(fs)
+    measure_ihcs(fs)
+    measure_synapses(fs)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/measurements/sgn_subtypes.py b/scripts/measurements/sgn_subtypes.py
new file mode 100644
index 0000000..9449d25
--- /dev/null
+++ b/scripts/measurements/sgn_subtypes.py
@@ -0,0 +1,348 @@
+import json
+import os
+from glob import glob
+from subprocess import run
+
+import matplotlib.pyplot as plt
+import pandas as pd
+from skimage.filters import threshold_otsu
+
+from flamingo_tools.s3_utils import BUCKET_NAME, create_s3_target, get_s3_path
+from flamingo_tools.measurements import compute_object_measures
+
+# Map from cochlea names to channels
+COCHLEAE_FOR_SUBTYPES = {
+    "M_LR_000099_L": ["PV", "Calb1", "Lypd1"],
+    "M_LR_000214_L": ["PV", "CR", "Calb1"],
+    "M_AMD_N62_L": ["PV", "CR", "Calb1"],
+    "M_AMD_N180_R": ["CR", "Ntng1", "CTBP2"],
+    "M_AMD_N180_L": ["CR", "Ntng1", "Lypd1"],
+    # Mutant / some stuff is weird.
+    # "M_AMD_Runx1_L": ["PV", "Lypd1", "Calb1"],
+    # This one still has to be stitched:
+    # "M_LR_000184_R": {"PV", "Prph"},
+}
+
+# Map from channels to subtypes.
+# Comment Aleyna:
+# The signal will be a gradient between different subtypes:
+# For example, CR is expressed more strongly, i.e. is brighter,
+# in type Ia SGNs, but exists in type Ib SGNs and to a lesser extent in type Ic.
+# The same is also true for the other markers, so we will need to set a threshold for each.
+# Luckily the signal seems less variable compared to GFP.
+CHANNEL_TO_TYPE = {
+    "CR": "Type-Ia",
+    "Calb1": "Type-Ib",
+    "Lypd1": "Type-Ic",
+    "Prph": "Type-II",
+    "Ntng1": "Type-Ib/c",
+}
+
+# For custom thresholds.
+THRESHOLDS = {
+    "M_LR_000214_L": {
+    },
+    "M_AMD_N62_L": {
+    },
+}
+
+PLOT_OUT = "./subtype_plots"
+
+
+def check_processing_status():
+    s3 = create_s3_target()
+
+    # For checking the dataset names.
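+    # (Uncomment the block below to list all datasets registered in the MoBIE project.)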
+    # content = s3.open(f"{BUCKET_NAME}/project.json", mode="r", encoding="utf-8")
+    # info = json.loads(content.read())
+    # datasets = info["datasets"]
+    # for name in datasets:
+    #     print(name)
+    # breakpoint()
+
+    missing_tables = {}
+
+    for cochlea, channels in COCHLEAE_FOR_SUBTYPES.items():
+        try:
+            content = s3.open(f"{BUCKET_NAME}/{cochlea}/dataset.json", mode="r", encoding="utf-8")
+        except FileNotFoundError:
+            print(cochlea, "is not yet on MoBIE")
+            print()
+            continue
+        info = json.loads(content.read())
+        sources = info["sources"]
+
+        channels_found = [name for name in channels if name in sources]
+        channels_missing = [name for name in channels if name not in sources]
+
+        print("Cochlea:", cochlea)
+        print("Found the expected channels:", channels_found)
+        if channels_missing:
+            print("Missing the expected channels:", channels_missing)
+
+        if "SGN_v2" in sources:
+            print("SGN segmentation is present with name SGN_v2")
+            seg_name = "SGN-v2"
+            table_folder = "tables/SGN_v2"
+        elif "PV_SGN_v2" in sources:
+            print("SGN segmentation is present with name PV_SGN_v2")
+            seg_name = "PV-SGN-v2"
+            table_folder = "tables/PV_SGN_v2"
+        elif "CR_SGN_v2" in sources:
+            print("SGN segmentation is present with name CR_SGN_v2")
+            seg_name = "CR-SGN-v2"
+            table_folder = "tables/CR_SGN_v2"
+        else:
+            print("SGN segmentation is MISSING")
+            print()
+            continue
+
+        # Check which tables we have.
+        if cochlea == "M_AMD_N180_L":  # we need all intensity measures here
+            seg_names = ["CR-SGN-v2", "Ntng1-SGN-v2", "Lypd1-SGN-v2"]
+            expected_tables = [f"{chan}_{sname}_object-measures.tsv" for chan in channels for sname in seg_names]
+        elif cochlea == "M_AMD_N180_R":
+            seg_names = ["CR-SGN-v2", "Ntng1-SGN-v2"]
+            expected_tables = [f"{chan}_{sname}_object-measures.tsv" for chan in channels for sname in seg_names]
+        else:
+            expected_tables = [f"{chan}_{seg_name}_object-measures.tsv" for chan in channels]
+
+        tables = s3.ls(os.path.join(BUCKET_NAME, cochlea, table_folder))
+        tables = [os.path.basename(tab) for tab in tables]
+
+        this_missing_tables = []
+        for exp_tab in expected_tables:
+            if exp_tab not in tables:
+                print("Missing table:", exp_tab)
+                this_missing_tables.append(exp_tab)
+        missing_tables[cochlea] = this_missing_tables
+        print()
+
+    return missing_tables
+
+
+def require_missing_tables(missing_tables):
+    output_root = "./object_measurements"
+
+    for cochlea, missing_tabs in missing_tables.items():
+        for missing in missing_tabs:
+            channel = missing.split("_")[0]
+            seg_name = missing.split("_")[1].replace("-", "_")
+            print("Computing intensities for cochlea:", cochlea, "segmentation:", seg_name, "channel:", channel)
+
+            img_s3 = f"{cochlea}/images/ome-zarr/{channel}.ome.zarr"
+            seg_s3 = f"{cochlea}/images/ome-zarr/{seg_name}.ome.zarr"
+            seg_table_s3 = f"{cochlea}/tables/{seg_name}/default.tsv"
+            img_path, _ = get_s3_path(img_s3)
+            seg_path, _ = get_s3_path(seg_s3)
+
+            output_folder = os.path.join(output_root, cochlea)
+            os.makedirs(output_folder, exist_ok=True)
+            output_table_path = os.path.join(
+                output_folder, f"{channel}_{seg_name.replace('_', '-')}_object-measures.tsv"
+            )
+            compute_object_measures(
+                image_path=img_path,
+                segmentation_path=seg_path,
+                segmentation_table_path=seg_table_s3,
+                output_table_path=output_table_path,
+                image_key="s0",
+                segmentation_key="s0",
+                s3_flag=True,
+                component_list=[1],
+                n_threads=16,
+            )
+
+            # S3 upload
+            run(["rclone", "--progress", "copyto", output_folder,
+                 f"cochlea-lightsheet:cochlea-lightsheet/{cochlea}/tables/{seg_name}"])
+
+
+def compile_data_for_subtype_analysis():
+    s3 = create_s3_target()
+
+    output_folder = "./subtype_analysis"
+    os.makedirs(output_folder, exist_ok=True)
+
+    for cochlea, channels in COCHLEAE_FOR_SUBTYPES.items():
+        if "PV" in channels:
+            reference_channel = "PV"
+            seg_name = "PV_SGN_v2"
+        else:
+            assert "CR" in channels
+            reference_channel = "CR"
+            seg_name = "CR_SGN_v2"
+
+        content = s3.open(f"{BUCKET_NAME}/{cochlea}/dataset.json", mode="r", encoding="utf-8")
+        info = json.loads(content.read())
+        sources = info["sources"]
+
+        # Load the segmentation table.
+        seg_source = sources[seg_name]
+        table_folder = os.path.join(
+            BUCKET_NAME, cochlea, seg_source["segmentation"]["tableData"]["tsv"]["relativePath"]
+        )
+        table_content = s3.open(os.path.join(table_folder, "default.tsv"), mode="rb")
+        table = pd.read_csv(table_content, sep="\t")
+
+        # Get the SGNs in the main component.
+        table = table[table.component_labels == 1]
+        valid_sgns = table.label_id
+
+        output_table = {"label_id": table.label_id.values, "frequency[kHz]": table["frequency[kHz]"]}
+
+        # Analyze the different channels (= different subtypes).
+        reference_intensity = None
+        for channel in channels:
+            # Load the intensity table.
+            intensity_path = os.path.join(table_folder, f"{channel}_{seg_name.replace('_', '-')}_object-measures.tsv")
+            table_content = s3.open(intensity_path, mode="rb")
+
+            intensities = pd.read_csv(table_content, sep="\t")
+            intensities = intensities[intensities.label_id.isin(valid_sgns)]
+            assert len(table) == len(intensities)
+            assert (intensities.label_id.values == table.label_id.values).all()
+
+            medians = intensities["median"].values
+            output_table[f"{channel}_median"] = medians
+            if channel == reference_channel:
+                reference_intensity = medians
+            else:
+                assert reference_intensity is not None
+                output_table[f"{channel}_ratio_{reference_channel}"] = medians / reference_intensity
+
+        out_path = os.path.join(output_folder, f"{cochlea}_subtype_analysis.tsv")
+        output_table = pd.DataFrame(output_table)
+        output_table.to_csv(out_path, sep="\t", index=False)
+
+
+def _plot_histogram(table, column, name, show_plots, subtype=None):
+    data = table[column].values
+    threshold = threshold_otsu(data)
+
+    fig, ax = plt.subplots(1)
+    ax.hist(data, bins=24)
+    ax.axvline(x=threshold, color='red', linestyle='--')
+    ax.set_title(f"{name}\n threshold: {threshold}")
+
+    if show_plots:
+        plt.show()
+    else:
+        os.makedirs(PLOT_OUT, exist_ok=True)
+        plt.savefig(f"{PLOT_OUT}/{name}.png")
+
+    if subtype is not None:
+        subtype_classification = [None if datum < threshold else subtype for datum in data]
+        return subtype_classification
+
+
+def _plot_2d(ratios, name, show_plots, classification=None):
+    fig, ax = plt.subplots(1)
+    assert len(ratios) == 2
+    keys = list(ratios.keys())
+    k1, k2 = keys
+
+    if classification is None:
+        ax.scatter(ratios[k1], ratios[k2])
+
+    else:
+        def _combine(a, b):
+            if a is None and b is None:
+                return None
+            elif a is None and b is not None:
+                return b
+            elif a is not None and b is None:
+                return a
+            else:
+                return f"{a}-{b}"
+
+        classification = [cls for cls in classification if cls is not None]
+        labels = classification[0].copy()
+        for cls in classification[1:]:
+            if cls is None:
+                continue
+            labels = [_combine(a, b) for a, b in zip(labels, cls)]
+
+        unique_labels = set(ll for ll in labels if ll is not None)
+        all_colors = ["red", "blue", "orange", "yellow"]
+        colors = {ll: color for ll, color in zip(unique_labels, all_colors[:len(unique_labels)])}
+
+        for lbl in unique_labels:
+            mask = [ll == lbl for ll in labels]
+            ax.scatter(
+                [ratios[k1][i] for i in range(len(labels)) if mask[i]],
+                [ratios[k2][i] for i in range(len(labels)) if mask[i]],
+                c=colors[lbl], label=lbl
+            )
+
+        mask_none = [ll is None for ll in labels]
+        ax.scatter(
+            [ratios[k1][i] for i in range(len(labels)) if mask_none[i]],
+            [ratios[k2][i] for i in range(len(labels)) if mask_none[i]],
+            facecolors="none", edgecolors="black", label="None"
+        )
+
+        ax.legend()
+
+    ax.set_xlabel(k1)
+    ax.set_ylabel(k2)
+    ax.set_title(name)
+
+    if show_plots:
+        plt.show()
+    else:
+        os.makedirs(PLOT_OUT, exist_ok=True)
+        plt.savefig(f"{PLOT_OUT}/{name}.png")
+
+
+# TODO enable over-writing by manual thresholds
+def analyze_subtype_data(show_plots=True):
+    files = sorted(glob("./subtype_analysis/*.tsv"))
+
+    for ff in files:
+        cochlea = os.path.basename(ff)[:-len("_subtype_analysis.tsv")]
+        print(cochlea)
+        channels = COCHLEAE_FOR_SUBTYPES[cochlea]
+        reference_channel = "PV" if "PV" in channels else "CR"
+        assert channels[0] == reference_channel
+
+        tab = pd.read_csv(ff, sep="\t")
+
+        # 1.) Plot simple intensity histograms, including otsu threshold.
+        for chan in channels:
+            column = f"{chan}_median"
+            name = f"{cochlea}_{chan}_histogram"
+            _plot_histogram(tab, column, name, show_plots)
+
+        # 2.) Plot ratio histograms, including otsu threshold.
+        # TODO ratio based classification and overlay in 2d plot?
+        ratios = {}
+        subtype_classification = []
+        for chan in channels[1:]:
+            column = f"{chan}_ratio_{reference_channel}"
+            name = f"{cochlea}_{chan}_histogram_ratio_{reference_channel}"
+            classification = _plot_histogram(
+                tab, column, name, subtype=CHANNEL_TO_TYPE.get(chan, None), show_plots=show_plots
+            )
+            subtype_classification.append(classification)
+            ratios[f"{chan}_{reference_channel}"] = tab[column].values
+
+        # 3.) Plot 2D space of ratios.
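+        # (Each point is one SGN; the axes are the two marker-to-reference intensity ratios,
+        # colored by the Otsu-threshold based subtype calls from step 2.)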
+        name = f"{cochlea}_2d"
+        _plot_2d(ratios, name, show_plots, classification=subtype_classification)
+
+
+# General notes:
+# See:
+def main():
+    missing_tables = check_processing_status()
+    require_missing_tables(missing_tables)
+
+    # compile_data_for_subtype_analysis()
+
+    # analyze_subtype_data(show_plots=False)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/measurements/synapse_colocalization.py b/scripts/measurements/synapse_colocalization.py
new file mode 100644
index 0000000..bbd5169
--- /dev/null
+++ b/scripts/measurements/synapse_colocalization.py
@@ -0,0 +1,155 @@
+import json
+import os
+
+import numpy as np
+import pandas as pd
+
+from flamingo_tools.s3_utils import BUCKET_NAME, SERVICE_ENDPOINT, create_s3_target, get_s3_path
+from flamingo_tools.validation import match_detections
+
+COCHLEA = "M_LR_000215_R"
+
+
+def _load_table(s3, source, valid_ihcs):
+    source_info = source["spots"]
+    rel_path = source_info["tableData"]["tsv"]["relativePath"]
+    table_content = s3.open(os.path.join(BUCKET_NAME, COCHLEA, rel_path, "default.tsv"), mode="rb")
+    synapse_table = pd.read_csv(table_content, sep="\t")
+    synapse_table_filtered = synapse_table[synapse_table.matched_ihc.isin(valid_ihcs)]
+    return synapse_table_filtered
+
+
+def _save_ihc_table(table, output_name):
+    ihc_ids, syn_per_ihc = np.unique(table.matched_ihc.values, return_counts=True)
+    ihc_ids = ihc_ids.astype("int")
+    ihc_to_count = {ihc_id: count for ihc_id, count in zip(ihc_ids, syn_per_ihc)}
+    ihc_count_table = pd.DataFrame({
+        "label_id": list(ihc_to_count.keys()), "synapse_count": list(ihc_to_count.values())
+    })
+    output_path = os.path.join("ihc_counts", f"ihc_count_{COCHLEA}_{output_name}.tsv")
+    ihc_count_table.to_csv(output_path, sep="\t", index=False)
+
+
+def _run_colocalization(riba_table, ctbp2_table, max_dist=2.0):
+    coords_riba = riba_table[["z", "y", "x"]].values
+    coords_ctbp2 = ctbp2_table[["z", "y", "x"]].values
+
+    matches_riba, matches_ctbp2, unmatched_riba, unmatched_ctbp2 = match_detections(
+        coords_riba, coords_ctbp2, max_dist=max_dist,
+    )
+    assert len(matches_riba) == len(matches_ctbp2)
+
+    # For quick visualization
+    if False:
+        matched_coords = coords_riba[matches_riba]
+
+        import napari
+        v = napari.Viewer()
+        v.add_points(coords_riba, name="RibA", face_color="orange")
+        v.add_points(coords_ctbp2, name="CTBP2")
+        v.add_points(matched_coords, name="Coloc", face_color="green")
+        napari.run()
+
+    return matches_riba, unmatched_riba, unmatched_ctbp2
+
+
+def _get_synapse_tables():
+    name_ctbp2 = "CTBP2_synapse_v3_ihc_v4b"
+    name_riba = "RibA_synapse_v3_ihc_v4b"
+
+    s3 = create_s3_target()
+    content = s3.open(f"{BUCKET_NAME}/{COCHLEA}/dataset.json", mode="r", encoding="utf-8")
+    info = json.loads(content.read())
+    sources = info["sources"]
+
+    ihc_table_path = sources["IHC_v4b"]["segmentation"]["tableData"]["tsv"]["relativePath"]
+    table_content = s3.open(os.path.join(BUCKET_NAME, COCHLEA, ihc_table_path, "default.tsv"), mode="rb")
+    ihc_table = pd.read_csv(table_content, sep="\t")
+    valid_ihcs = ihc_table.label_id[ihc_table.is_ihc == 1].values
+
+    riba_table = _load_table(s3, sources[name_riba], valid_ihcs)
+    ctbp2_table = _load_table(s3, sources[name_ctbp2], valid_ihcs)
+
+    return riba_table, ctbp2_table, valid_ihcs
+
+
+def check_and_filter_synapses():
+    riba_table, ctbp2_table, valid_ihcs = _get_synapse_tables()
+
+    # Save the single synapse marker tables.
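+    # (These write the per-IHC synapse counts for each marker before colocalization filtering.)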
+    _save_ihc_table(riba_table, "RibA")
+    _save_ihc_table(ctbp2_table, "CTBP2")
+
+    # Run co-localization, analyze it and save the table.
+    matches_riba, unmatched_riba, unmatched_ctbp2 = _run_colocalization(riba_table, ctbp2_table)
+
+    n_matched = len(matches_riba)
+    print("Number of IHCs:", len(valid_ihcs))
+    print("Number of matched synapses:", n_matched)
+    print()
+
+    n_ctbp2 = n_matched + len(unmatched_ctbp2)
+    n_riba = n_matched + len(unmatched_riba)
+    print("Number and percentage of matched synapses for markers:")
+    print("CTBP2:", n_matched, "/", n_ctbp2, f"({float(n_matched) / n_ctbp2 * 100}% matched)")
+    print("RibA :", n_matched, "/", n_riba, f"({float(n_matched) / n_riba * 100}% matched)")
+
+    coloc_table = riba_table.iloc[matches_riba]
+    _save_ihc_table(coloc_table, "coloc")
+
+
+def check_synapse_predictions():
+    import napari
+    import z5py
+    import zarr
+
+    pred_path_ctbp2 = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/M_LR_000215_R/CTBP2_synapses_v3/predictions.zarr"  # noqa
+    # pred_path_riba = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/M_LR_000215_R/RibA_synapses_v3"  # noqa
+
+    location_mobie = [836.654754589637, 1255.8010489404858, 677.1537312920972]
+    resolution = 0.38
+    location = [int(loc / resolution) for loc in location_mobie[::-1]]
+
+    halo = (32, 256, 256)
+    start = np.array([loc - ha for loc, ha in zip(location, halo)])
+    stop = np.array([loc + ha for loc, ha in zip(location, halo)])
+    bb = tuple(slice(sta, sto) for sta, sto in zip(start, stop))
+
+    print("Loading tables ...")
+    _, ctbp2_table, _ = _get_synapse_tables()
+    ids = ctbp2_table.spot_id.values
+    coords = ctbp2_table[["z", "y", "x"]].values / resolution
+    mask = np.logical_and(
+        (coords > start[None, :]).all(axis=1),
+        (coords < stop[None, :]).all(axis=1),
+    )
+    ids = ids[mask]
+    det_ctbp2 = coords[mask]
+    det_ctbp2 -= start[None, :]
+    print("Found", len(det_ctbp2), "detections in the crop")
+
+    print("Loading image data ...")
+    s3_store, fs = get_s3_path(
+        f"{COCHLEA}/images/ome-zarr/CTBP2.ome.zarr", bucket_name=BUCKET_NAME, service_endpoint=SERVICE_ENDPOINT
+    )
+    f = zarr.open(s3_store, mode="r")
+    ctbp2 = f["s0"][bb]
+
+    print("Loading prediction ...")
+    with z5py.File(pred_path_ctbp2, "r") as f:
+        pred_ctbp2 = f["prediction"][bb]
+
+    v = napari.Viewer()
+    v.add_image(ctbp2)
+    v.add_image(pred_ctbp2)
+    v.add_points(det_ctbp2)
+    napari.run()
+
+
+def main():
+    # check_and_filter_synapses()
+    check_synapse_predictions()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/more-annotations/extract_blocks_da_mlr99l.py b/scripts/more-annotations/extract_blocks_da_mlr99l.py
new file mode 100644
index 0000000..398584b
--- /dev/null
+++ b/scripts/more-annotations/extract_blocks_da_mlr99l.py
@@ -0,0 +1,26 @@
+from flamingo_tools.extract_block_util import extract_block
+
+
+positions_with_signal = [
+    [1215.074063453085, 912.697256780485, 1036.814204517708],
+    [1030.351117830933, 1262.3358840155736, 1123.2581736686361],
+    [1192.167776682008, 354.058713359485, 767.1544606203263],
+    [916.9294364078347, 754.7061965177552, 923.607923806173],
+]
+positions_without_signal = [
+    [1383.4288658807268, 783.0008672288084, 467.5426478786816],
+]
+
+halo = [256, 256, 64]
+
+
+for pos in positions_with_signal + positions_without_signal:
+    extract_block(
+        input_path="M_LR_000099_L/images/ome-zarr/PV.ome.zarr",
+        coords=pos,
+        output_dir="./MLR99L_for_DA",
+        input_key="s0",
+        roi_halo=halo,
+        tif=True,
+        s3=True
+    )
diff --git a/scripts/more-annotations/extract_sgn_annotations.py b/scripts/more-annotations/extract_sgn_annotations.py
new file mode 100644
index 0000000..bd300d9
--- /dev/null
+++ b/scripts/more-annotations/extract_sgn_annotations.py
@@ -0,0 +1,144 @@
+import os
+from flamingo_tools.extract_block_util import extract_block
+
+RESOLUTION_LA_VISION = (1.887779, 1.887779, 3.000000)
+RESOLUTION_FLAMINGO = (0.38, 0.38, 0.38)
+
+POSITIONS = [
+    [2451.991261845771, 2497.0312568466725, 504.00000000000017],
+    [2364.0285060661868, 2104.541310616445, 684.2966391951725],
+    [2579.872281689804, 2266.294057961108, 532.8474622712272],
+    [2251.404321115024, 1972.6189003459485, 313.69577047550024],
+]
+
+EMPTY_POSITIONS = [
+    [3091.354274253545, 1396.2702622881343, 443.21051449917223],
+    [1052.8399693103918, 2180.579279395121, 437.81154147679354],
+    [3621.8731222257875, 1602.0695390382377, 620.0517925327181],
+]
+
+
+def download_lavision_crops():
+    input_path = "LaVision-M04/images/ome-zarr/PV.ome.zarr"
+    input_key = "s0"
+    output_key = None
+
+    output_folder = "./LA_VISION_M04"
+    os.makedirs(output_folder, exist_ok=True)
+    for pos in POSITIONS:
+        halo = [128, 128, 32]
+        extract_block(
+            input_path, pos, output_folder, input_key, output_key, RESOLUTION_LA_VISION, halo,
+            tif=True, s3=True,
+        )
+
+    output_folder = "./LA_VISION_M04_empty"
+    os.makedirs(output_folder, exist_ok=True)
+    for pos in EMPTY_POSITIONS:
+        halo = [128, 128, 32]
+        extract_block(
+            input_path, pos, output_folder, input_key, output_key, RESOLUTION_LA_VISION, halo,
+            tif=True, s3=True,
+        )
+
+
+def downscale_segmentation():
+    # Scale levels:
+    # 0: 0.38
+    # 1: 0.76
+    # 2: 1.52
+    cochleae_and_positions = {
+        "M_LR_000226_R": [
+            [709.1792864323405, 277.94313087502087, 790.2787473759703],
+            [684.1211492422168, 551.0610808519966, 972.7784147188805],
+            [855.2911547522649, 893.5164525605765, 781.6745184537485],
+            [805.856486020322, 1087.1388983637446, 872.4092720709023],
+        ],
+        "M_LR_000226_L": [
+            [728.811391169819, 787.126384246222, 765.5121274338735],
+            [310.8110214721421, 503.69151338122936, 433.37560298279783],
+            [409.56553632355974, 773.9536143837831, 926.4997632186463],
+        ],
+        "M_LR_000227_R": [
+            [802.695142936733, 928.040906650113, 787.9300000000001],
+            [539.6960827733835, 837.7146969656125, 787.9300000000001],
+            [460.70492230292973, 366.6096043369565, 909.2776827283466],
+        ],
+        "M_LR_000227_L": [
+            [583.3657358293676, 835.4967569151809, 754.4900000000004],
+            [846.4954841793519, 963.2748384826734, 706.7658788868116],
+            [927.8483264319711, 746.0723412831164, 578.0355803590329],
+        ],
+    }
+
+    input_key = "s2"
+    halo = [128, 128, 128]
+    resolution = [1.52,] * 3
+    output_key = None
+
+    image_out_folder = "./downscaled_sgn_labels/images"
+    label_out_folder = "./downscaled_sgn_labels/labels"
+
+    for cochlea, positions in cochleae_and_positions.items():
+        print("Extracting blocks for", cochlea)
+        input_path = f"{cochlea}/images/ome-zarr/PV.ome.zarr"
+        seg_path = f"{cochlea}/images/ome-zarr/SGN_v2.ome.zarr"
+        for position in positions:
+            extract_block(
+                input_path, position, image_out_folder, input_key, output_key, resolution, halo,
+                tif=True, s3=True, scale_factor=(0.5, 1, 1),
+            )
+            extract_block(
+                seg_path, position, label_out_folder, input_key, output_key, resolution, halo,
+                tif=True, s3=True, scale_factor=(0.5, 1, 1),
+            )
+
+
+# Note: consider a different normalization strategy for these cochleae and normalize by local intensity
+# rather than by global values.
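+# A rough sketch of such a local normalization (untested; this helper is not part of this PR):
+#   import numpy as np
+#   from skimage.filters import gaussian
+#   def normalize_local(x, sigma=32):
+#       x = x.astype("float32")
+#       local_mean = gaussian(x, sigma)
+#       local_std = np.sqrt(gaussian((x - local_mean) ** 2, sigma)) + 1e-6
+#       return (x - local_mean) / local_std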
+# Also double check empty positions again and make sure they don't contain SGNs.
+def download_lavision_crops2():
+    input_key = "s0"
+    output_key = None
+
+    # Additional positions for LaVision annotations:
+    input_path = "LaVision-M04/images/ome-zarr/PV.ome.zarr"
+    new_positions_m04 = [
+        [2031.0655170248258, 1925.206039671767, 249.14546086048554],
+        [2378.3720460599393, 2105.471228531872, 303.9285928812524],
+        [1619.3251178227529, 3444.7351705689553, 271.2360278843609],
+        [2358.2784398426843, 1503.2211953830192, 762.7325586759833],
+    ]
+
+    output_folder = "./LA_VISION_M04_2"
+    os.makedirs(output_folder, exist_ok=True)
+    for pos in new_positions_m04:
+        halo = [128, 128, 32]
+        extract_block(
+            input_path, pos, output_folder, input_key, output_key, RESOLUTION_LA_VISION, halo,
+            tif=True, s3=True,
+        )
+
+    # Position in Marmoset:
+    new_positions_mar05 = [
+        [2462.7875134103206, 2818.067344942212, 1177.1380214828991]
+    ]
+    input_path = "LaVision-Mar05/images/ome-zarr/PV.ome.zarr"
+    output_folder = "./LA_VISION_Mar05"
+    os.makedirs(output_folder, exist_ok=True)
+    for pos in new_positions_mar05:
+        halo = [128, 128, 32]
+        extract_block(
+            input_path, pos, output_folder, input_key, output_key, RESOLUTION_LA_VISION, halo,
+            tif=True, s3=True,
+        )
+
+
+def main():
+    # download_lavision_crops()
+    # downscale_segmentation()
+    download_lavision_crops2()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/more-annotations/extract_vglut3_gerbil.py b/scripts/more-annotations/extract_vglut3_gerbil.py
new file mode 100644
index 0000000..19185d4
--- /dev/null
+++ b/scripts/more-annotations/extract_vglut3_gerbil.py
@@ -0,0 +1,82 @@
+import os
+import json
+from flamingo_tools.extract_block_util import extract_block
+
+# Segmentation G_EK_000233_L IHC_v5:
+# Components: 1, 2, 3, 4, 5, 8
+
+
+def initial_blocks():
+    blocks_to_annotate = [
+        "[1157.0001356895104,1758.4586038866773,994.1494008046312]",
+        "[1257.4854619519856,1712.418054399143,942.993234707371]",
+        "[1329.892068232878,1421.8487165158099,712.9291398862247]",
+        "[1035.3286672774282,1844.919679510697,826.5595378176982]"
+    ]
+
+    empty_blocks = [
+        "[1066.8140315229548,1654.678601097876,994.1494008046308]",
+        "[1372.9314226188667,1698.1843589090392,805.5965454893357]",
+        "[1079.515087933512,1425.90033123735,1006.1353228190363]",
+        "[825.208612674945,565.9088211207202,1170.995860154057]",
+    ]
+
+    input_path = "G_EK_000233_L/images/ome-zarr/Vglut3.ome.zarr"
+    input_key = "s0"
+    output_key = None
+    resolution = 0.38
+    # Define the halo once, so that both loops below use the same value.
+    halo = [196, 196, 48]
+
+    output_folder = "./G_EK_000233_L_VGlut3"
+    os.makedirs(output_folder, exist_ok=True)
+    for coords in blocks_to_annotate:
+        coords = json.loads(coords)
+        extract_block(
+            input_path, coords, output_folder, input_key, output_key, resolution, halo,
+            tif=True, s3=True,
+        )
+
+    output_folder = "./G_EK_000233_L_VGlut3_empty"
+    os.makedirs(output_folder, exist_ok=True)
+    for coords in empty_blocks:
+        coords = json.loads(coords)
+        extract_block(
+            input_path, coords, output_folder, input_key, output_key, resolution, halo,
+            tif=True, s3=True,
+        )
+
+
+def next_blocks():
+    blocks_to_annotate = [
+        "[1369.3745663030836,1518.1919905173781,629.7304794154277]",
+        "[1455.8392111279102,1678.0405530924381,706.4828408796568]",
+        "[787.5688223108835,937.4842970917134,254.5776808983996]",
+        "[673.778067707707,1047.544514258375,1573.6031222817312]",
+        "[593.1568182540034,1018.005901151162,276.4582768781532]",
+        "[962.1827956246404,1973.0958986758776,756.6812813123805]",
+    ]
+
+    input_path = "G_EK_000233_L/images/ome-zarr/Vglut3.ome.zarr"
+    input_key = "s0"
+    output_key = None
+    resolution = 0.38
+    halo = [196, 196, 48]
+
+    output_folder = "./G_EK_000233_L_VGlut3-round2"
+    os.makedirs(output_folder, exist_ok=True)
+    print("Exporting to", output_folder)
+    for coords in blocks_to_annotate:
+        coords = json.loads(coords)
+        extract_block(
+            input_path, coords, output_folder, input_key, output_key, resolution, halo,
+            tif=True, s3=True,
+        )
+
+
+def main():
+    # initial_blocks()
+    next_blocks()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/training/train_distance_unet.py b/scripts/training/train_distance_unet.py
index 98daefa..178bc81 100644
--- a/scripts/training/train_distance_unet.py
+++ b/scripts/training/train_distance_unet.py
@@ -39,6 +39,21 @@ def get_image_and_label_paths(root):
     return image_paths, label_paths
 
 
+def get_image_and_label_paths_sep_folders(root):
+    image_paths = sorted(glob(os.path.join(root, "images", "**", "*.tif"), recursive=True))
+    label_paths = sorted(glob(os.path.join(root, "labels", "**", "*.tif"), recursive=True))
+    assert len(image_paths) == len(label_paths)
+
+    # import imageio.v3 as imageio
+    # for imp, lp in zip(image_paths, label_paths):
+    #     s1 = imageio.imread(imp).shape
+    #     s2 = imageio.imread(lp).shape
+    #     if s1 != s2:
+    #         breakpoint()
+
+    return image_paths, label_paths
+
+
 def select_paths(image_paths, label_paths, split, filter_empty):
     if filter_empty:
         image_paths = [imp for imp in image_paths if "empty" not in imp]
@@ -60,8 +75,11 @@ def select_paths(image_paths, label_paths, split, filter_empty):
     return image_paths, label_paths
 
 
-def get_loader(root, split, patch_shape, batch_size, filter_empty):
-    image_paths, label_paths = get_image_and_label_paths(root)
+def get_loader(root, split, patch_shape, batch_size, filter_empty, separate_folders):
+    if separate_folders:
+        image_paths, label_paths = get_image_and_label_paths_sep_folders(root)
+    else:
+        image_paths, label_paths = get_image_and_label_paths(root)
 
     this_image_paths, this_label_paths = select_paths(image_paths, label_paths, split, filter_empty)
     assert len(this_image_paths) == len(this_label_paths)
@@ -97,6 +115,7 @@ def main():
     parser.add_argument(
         "--name", help="Optional name for the model to be trained. If not given the current date is used."
     )
+    parser.add_argument("--separate_folders", action="store_true")
     args = parser.parse_args()
     root = args.root
     batch_size = args.batch_size
@@ -106,14 +125,18 @@ def main():
 
     # Parameters for training on A100.
     n_iterations = int(1e5)
-    patch_shape = (64, 128, 128)
+    patch_shape = (48, 128, 128)
 
     # The U-Net.
     model = get_3d_model()
 
     # Create the training loader with train and val set.
-    train_loader = get_loader(root, "train", patch_shape, batch_size, filter_empty=filter_empty)
-    val_loader = get_loader(root, "val", patch_shape, batch_size, filter_empty=filter_empty)
+    train_loader = get_loader(
+        root, "train", patch_shape, batch_size, filter_empty=filter_empty, separate_folders=args.separate_folders
+    )
+    val_loader = get_loader(
+        root, "val", patch_shape, batch_size, filter_empty=filter_empty, separate_folders=args.separate_folders
+    )
 
     if check_loaders:
         from torch_em.util.debug import check_loader