diff --git a/scripts/cooper/training/evaluate_compartments.py b/scripts/cooper/training/evaluate_compartments.py
new file mode 100644
index 00000000..d3854223
--- /dev/null
+++ b/scripts/cooper/training/evaluate_compartments.py
@@ -0,0 +1,103 @@
+import os
+import h5py
+import numpy as np
+import pandas as pd
+
+from synapse_net.inference.inference import get_model
+from synapse_net.inference.compartments import segment_compartments
+from skimage.segmentation import find_boundaries
+
+from elf.evaluation.matching import matching
+
+from train_compartments import get_paths_3d
+from sklearn.model_selection import train_test_split
+
+
+def run_prediction(paths):
+    output_folder = "./compartment_eval"
+    os.makedirs(output_folder, exist_ok=True)
+
+    model = get_model("compartments")
+    for path in paths:
+        with h5py.File(path, "r") as f:
+            input_vol = f["raw"][:]
+        seg, pred = segment_compartments(input_vol, model=model, return_predictions=True)
+        fname = os.path.basename(path)
+        out = os.path.join(output_folder, fname)
+        with h5py.File(out, "a") as f:
+            f.create_dataset("seg", data=seg, compression="gzip")
+            f.create_dataset("pred", data=pred, compression="gzip")
+
+
+def binary_recall(gt, pred):
+    tp = np.logical_and(gt, pred).sum()
+    fn = np.logical_and(gt, ~pred).sum()
+    return float(tp) / (tp + fn) if (tp + fn) else 0.0
+
+
+def run_evaluation(paths):
+    output_folder = "./compartment_eval"
+
+    results = {
+        "name": [],
+        "recall-pred": [],
+        "recall-seg": [],
+    }
+
+    for path in paths:
+        with h5py.File(path, "r") as f:
+            labels = f["labels/compartments"][:]
+        boundary_labels = find_boundaries(labels).astype("bool")
+
+        fname = os.path.basename(path)
+        out = os.path.join(output_folder, fname)
+        with h5py.File(out, "a") as f:
+            seg, pred = f["seg"][:], f["pred"][:]
+
+        recall_pred = binary_recall(boundary_labels, pred > 0.5)
+        recall_seg = matching(seg, labels)["recall"]
+
+        results["name"].append(fname)
+        results["recall-pred"].append(recall_pred)
+        results["recall-seg"].append(recall_seg)
+
+    results = pd.DataFrame(results)
+    print(results)
+    print(results[["recall-pred", "recall-seg"]].mean())
+
+
+def check_predictions(paths):
+    import napari
+    output_folder = "./compartment_eval"
+
+    for path in paths:
+        with h5py.File(path, "r") as f:
+            raw = f["raw"][:]
+            labels = f["labels/compartments"][:]
+        boundary_labels = find_boundaries(labels)
+
+        fname = os.path.basename(path)
+        out = os.path.join(output_folder, fname)
+        with h5py.File(out, "a") as f:
+            seg, pred = f["seg"][:], f["pred"][:]
+
+        v = napari.Viewer()
+        v.add_image(raw)
+        v.add_image(pred)
+        v.add_labels(labels)
+        v.add_labels(boundary_labels)
+        v.add_labels(seg)
+        napari.run()
+
+
+def main():
+    paths = get_paths_3d()
+    _, val_paths = train_test_split(paths, test_size=0.10, random_state=42)
+
+    # run_prediction(val_paths)
+    run_evaluation(val_paths)
+    # check_predictions(val_paths)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/cooper/training/train_compartments.py b/scripts/cooper/training/train_compartments.py
index 5d568c5b..7b3e01b1 100644
--- a/scripts/cooper/training/train_compartments.py
+++ b/scripts/cooper/training/train_compartments.py
@@ -14,7 +14,6 @@ from synapse_net.training import supervised_training
 
 
 TRAIN_ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/ground_truth/compartments"
-# TRAIN_ROOT = "/home/pape/Work/my_projects/synaptic-reconstruction/scripts/cooper/ground_truth/compartments/output/compartment_gt"  # noqa
 
 
 def get_paths_2d():
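Note on the two metrics in evaluate_compartments.py: "recall-pred" is a pixel-wise recall of the thresholded boundary prediction against boundaries derived from the ground-truth labels, while "recall-seg" is an instance-level recall computed with elf's IoU-based matching. A minimal sketch on made-up toy arrays (not part of the patch; only the binary_recall arithmetic and the matching(seg, gt)["recall"] usage are taken from the script above):

import numpy as np
from elf.evaluation.matching import matching

# Two ground-truth instances; the segmentation recovers only the first one.
gt = np.zeros((8, 8), dtype="uint32")
gt[1:4, 1:4] = 1
gt[5:8, 5:8] = 2
seg = np.zeros_like(gt)
seg[1:4, 1:4] = 1

# Pixel-level recall on the binarized masks: 9 of 18 foreground pixels are hit.
fg_gt, fg_seg = gt > 0, seg > 0
tp = np.logical_and(fg_gt, fg_seg).sum()
fn = np.logical_and(fg_gt, ~fg_seg).sum()
print(float(tp) / (tp + fn))        # 0.5

# Instance-level recall via IoU matching: 1 of 2 instances is matched.
print(matching(seg, gt)["recall"])  # 0.5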
diff --git a/synapse_net/inference/inference.py b/synapse_net/inference/inference.py
index 9cbdd143..47147d71 100644
--- a/synapse_net/inference/inference.py
+++ b/synapse_net/inference/inference.py
@@ -22,7 +22,7 @@ def _get_model_registry():
     registry = {
-        "active_zone": "a18f29168aed72edec0f5c2cb1aa9a4baa227812db6082a6538fd38d9f43afb0",
+        "active_zone": "c23652a8fe06daa113546af6d3200c4c1dcc79917056c6ed7357b8c93548372a",
         "compartments": "527983720f9eb215c45c4f4493851fd6551810361eda7b79f185a0d304274ee1",
         "mitochondria": "24625018a5968b36f39fa9d73b121a32e8f66d0f2c0540d3df2e1e39b3d58186",
         "mitochondria2": "553decafaff4838fff6cc8347f22c8db3dee5bcbeffc34ffaec152f8449af673",
@@ -37,7 +37,7 @@ def _get_model_registry():
         "vesicles_3d_innerear": "924f0f7cfb648a3a6931c1d48d8b1fdc6c0c0d2cb3330fe2cae49d13e7c3b69d",
     }
     urls = {
-        "active_zone": "https://owncloud.gwdg.de/index.php/s/zvuY342CyQebPsX/download",
+        "active_zone": "https://owncloud.gwdg.de/index.php/s/wpea9FH9waG4zJd/download",
         "compartments": "https://owncloud.gwdg.de/index.php/s/DnFDeTmDDmZrDDX/download",
         "mitochondria": "https://owncloud.gwdg.de/index.php/s/1T542uvzfuruahD/download",
         "mitochondria2": "https://owncloud.gwdg.de/index.php/s/GZghrXagc54FFXd/download",
@@ -109,7 +109,7 @@ def get_model_training_resolution(model_type: str) -> Dict[str, float]:
         Mapping of axis (x, y, z) to the voxel size (in nm) of that axis.
     """
     resolutions = {
-        "active_zone": {"x": 1.44, "y": 1.44, "z": 1.44},
+        "active_zone": {"x": 1.38, "y": 1.38, "z": 1.38},
         "compartments": {"x": 3.47, "y": 3.47, "z": 3.47},
         "mitochondria": {"x": 2.07, "y": 2.07, "z": 2.07},
         "cristae": {"x": 1.44, "y": 1.44, "z": 1.44},
diff --git a/synapse_net/tools/cli.py b/synapse_net/tools/cli.py
index 60900001..6b4e44f1 100644
--- a/synapse_net/tools/cli.py
+++ b/synapse_net/tools/cli.py
@@ -1,7 +1,9 @@
 import argparse
+import os
 from functools import partial
 
 import torch
+import torch_em
 from ..imod.to_imod import export_helper, write_segmentation_to_imod_as_points, write_segmentation_to_imod
 from ..inference.inference import _get_model_registry, get_model, get_model_training_resolution, run_segmentation
 from ..inference.util import inference_helper, parse_tiling
@@ -155,7 +157,14 @@ def segmentation_cli():
     if args.checkpoint is None:
         model = get_model(args.model)
     else:
-        model = torch.load(args.checkpoint, weights_only=False)
+        checkpoint_path = args.checkpoint
+        if checkpoint_path.endswith("best.pt"):
+            checkpoint_path = os.path.split(checkpoint_path)[0]
+
+        if os.path.isdir(checkpoint_path):  # Load the model from a torch_em checkpoint.
+            model = torch_em.util.load_model(checkpoint=checkpoint_path)
+        else:
+            model = torch.load(checkpoint_path, weights_only=False)
     assert model is not None, f"The model from {args.checkpoint} could not be loaded."
 
     is_2d = "2d" in args.model
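The checkpoint handling added to segmentation_cli can be read in isolation as the following helper. This is a sketch for illustration only: the function name _load_custom_model is made up here, while torch_em.util.load_model and the torch.load call are exactly the ones used in the diff.

import os
import torch
import torch_em


def _load_custom_model(checkpoint):
    # A path ending in "best.pt" is mapped back to its parent directory, so users
    # can pass either the checkpoint directory or the file torch_em writes into it.
    if checkpoint.endswith("best.pt"):
        checkpoint = os.path.split(checkpoint)[0]

    if os.path.isdir(checkpoint):
        # Directories are treated as torch_em checkpoints.
        return torch_em.util.load_model(checkpoint=checkpoint)

    # Anything else is assumed to be a model serialized directly with torch.save.
    return torch.load(checkpoint, weights_only=False)

With this, a torch_em checkpoint directory, the best.pt file inside it, and a plain torch-saved model file (paths are hypothetical) all resolve to a loaded model through the same CLI flag.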