Commit e6f86fc (merge commit, 2 parents: 292e450 + 9c252ed)

background mask for unsupervised training

10 files changed: +414 additions, -33 deletions
Lines changed: 103 additions & 0 deletions (new file)

@@ -0,0 +1,103 @@
import os

import h5py
import numpy as np
import pandas as pd

from synapse_net.inference.inference import get_model
from synapse_net.inference.compartments import segment_compartments
from skimage.segmentation import find_boundaries

from elf.evaluation.matching import matching

from train_compartments import get_paths_3d
from sklearn.model_selection import train_test_split


def run_prediction(paths):
    """Run compartment prediction for the given tomograms and save segmentations and predictions."""
    output_folder = "./compartment_eval"
    os.makedirs(output_folder, exist_ok=True)

    model = get_model("compartments")
    for path in paths:
        with h5py.File(path, "r") as f:
            input_vol = f["raw"][:]
        seg, pred = segment_compartments(input_vol, model=model, return_predictions=True)
        fname = os.path.basename(path)
        out = os.path.join(output_folder, fname)
        with h5py.File(out, "a") as f:
            f.create_dataset("seg", data=seg, compression="gzip")
            f.create_dataset("pred", data=pred, compression="gzip")


def binary_recall(gt, pred):
    # Pixel-wise recall: TP / (TP + FN); returns 0 if the ground truth is empty.
    tp = np.logical_and(gt, pred).sum()
    fn = np.logical_and(gt, ~pred).sum()
    return float(tp) / (tp + fn) if (tp + fn) else 0.0


def run_evaluation(paths):
    output_folder = "./compartment_eval"

    results = {
        "name": [],
        "recall-pred": [],
        "recall-seg": [],
    }

    for path in paths:
        with h5py.File(path, "r") as f:
            labels = f["labels/compartments"][:]
        boundary_labels = find_boundaries(labels).astype("bool")

        fname = os.path.basename(path)
        out = os.path.join(output_folder, fname)
        with h5py.File(out, "a") as f:
            seg, pred = f["seg"][:], f["pred"][:]

        # Pixel-wise recall of the boundary prediction and instance-level recall of the segmentation.
        recall_pred = binary_recall(boundary_labels, pred > 0.5)
        recall_seg = matching(seg, labels)["recall"]

        results["name"].append(fname)
        results["recall-pred"].append(recall_pred)
        results["recall-seg"].append(recall_seg)

    results = pd.DataFrame(results)
    print(results)
    print(results[["recall-pred", "recall-seg"]].mean())


def check_predictions(paths):
    import napari
    output_folder = "./compartment_eval"

    for path in paths:
        with h5py.File(path, "r") as f:
            raw = f["raw"][:]
            labels = f["labels/compartments"][:]
        boundary_labels = find_boundaries(labels)

        fname = os.path.basename(path)
        out = os.path.join(output_folder, fname)
        with h5py.File(out, "a") as f:
            seg, pred = f["seg"][:], f["pred"][:]

        v = napari.Viewer()
        v.add_image(raw)
        v.add_image(pred)
        v.add_labels(labels)
        v.add_labels(boundary_labels)
        v.add_labels(seg)
        napari.run()


def main():
    paths = get_paths_3d()
    _, val_paths = train_test_split(paths, test_size=0.10, random_state=42)

    # run_prediction(val_paths)
    run_evaluation(val_paths)
    # check_predictions(val_paths)


if __name__ == "__main__":
    main()
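For reference, binary_recall above is plain pixel-wise recall, TP / (TP + FN), evaluated here on the boundary maps. A minimal sketch with made-up toy arrays illustrates the computation:

import numpy as np

# Hypothetical toy masks: 3 of the 4 ground-truth foreground pixels are predicted.
gt = np.array([1, 1, 1, 1, 0, 0], dtype=bool)
pred = np.array([1, 1, 1, 0, 0, 1], dtype=bool)

tp = np.logical_and(gt, pred).sum()   # true positives: 3
fn = np.logical_and(gt, ~pred).sum()  # false negatives: 1
print(tp / (tp + fn))                 # 0.75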

scripts/cooper/training/train_compartments.py

Lines changed: 0 additions & 1 deletion

@@ -14,7 +14,6 @@
 from synapse_net.training import supervised_training

 TRAIN_ROOT = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/ground_truth/compartments"
-# TRAIN_ROOT = "/home/pape/Work/my_projects/synaptic-reconstruction/scripts/cooper/ground_truth/compartments/output/compartment_gt"  # noqa


 def get_paths_2d():

synapse_net/__version__.py

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-__version__ = "0.2.0"
+__version__ = "0.3.0"

synapse_net/ground_truth/vesicles.py

Lines changed: 1 addition & 1 deletion

@@ -227,7 +227,7 @@ def extract_vesicle_training_data(
         relative_path = os.path.relpath(file_path, data_folder)

         if to_label_path is None:
-            imod_path = os.path.join(gt_folder, relative_path.replace(Path(relative_path).suffix, ".imod"))
+            imod_path = os.path.join(gt_folder, relative_path.replace(Path(relative_path).suffix, ".mod"))
         else:
             imod_path = to_label_path(gt_folder, relative_path)
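For illustration, the fixed default now swaps the file suffix for the IMOD .mod extension; a quick sketch with hypothetical file names:

import os
from pathlib import Path

relative_path = "tomo_01/raw.h5"  # hypothetical relative path
gt_folder = "gt"                  # hypothetical ground-truth folder
imod_path = os.path.join(gt_folder, relative_path.replace(Path(relative_path).suffix, ".mod"))
print(imod_path)  # gt/tomo_01/raw.mod (on POSIX systems)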

synapse_net/inference/inference.py

Lines changed: 3 additions & 3 deletions

@@ -22,7 +22,7 @@

 def _get_model_registry():
     registry = {
-        "active_zone": "a18f29168aed72edec0f5c2cb1aa9a4baa227812db6082a6538fd38d9f43afb0",
+        "active_zone": "c23652a8fe06daa113546af6d3200c4c1dcc79917056c6ed7357b8c93548372a",
         "compartments": "527983720f9eb215c45c4f4493851fd6551810361eda7b79f185a0d304274ee1",
         "mitochondria": "24625018a5968b36f39fa9d73b121a32e8f66d0f2c0540d3df2e1e39b3d58186",
         "mitochondria2": "553decafaff4838fff6cc8347f22c8db3dee5bcbeffc34ffaec152f8449af673",
@@ -37,7 +37,7 @@ def _get_model_registry():
         "vesicles_3d_innerear": "924f0f7cfb648a3a6931c1d48d8b1fdc6c0c0d2cb3330fe2cae49d13e7c3b69d",
     }
     urls = {
-        "active_zone": "https://owncloud.gwdg.de/index.php/s/zvuY342CyQebPsX/download",
+        "active_zone": "https://owncloud.gwdg.de/index.php/s/wpea9FH9waG4zJd/download",
         "compartments": "https://owncloud.gwdg.de/index.php/s/DnFDeTmDDmZrDDX/download",
         "mitochondria": "https://owncloud.gwdg.de/index.php/s/1T542uvzfuruahD/download",
         "mitochondria2": "https://owncloud.gwdg.de/index.php/s/GZghrXagc54FFXd/download",
@@ -109,7 +109,7 @@ def get_model_training_resolution(model_type: str) -> Dict[str, float]:
         Mapping of axis (x, y, z) to the voxel size (in nm) of that axis.
     """
     resolutions = {
-        "active_zone": {"x": 1.44, "y": 1.44, "z": 1.44},
+        "active_zone": {"x": 1.38, "y": 1.38, "z": 1.38},
         "compartments": {"x": 3.47, "y": 3.47, "z": 3.47},
         "mitochondria": {"x": 2.07, "y": 2.07, "z": 2.07},
         "cristae": {"x": 1.44, "y": 1.44, "z": 1.44},
Lines changed: 137 additions & 0 deletions (new file)

@@ -0,0 +1,137 @@
import os
import tempfile
from typing import Dict, List, Optional

import elf.parallel as parallel
import numpy as np
import torch

from elf.io import open_file
from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper
from elf.wrapper.base import MultiTransformationWrapper
from elf.wrapper.resized_volume import ResizedVolume
from numpy.typing import ArrayLike
from synapse_net.inference.util import get_prediction


class SelectChannel(SimpleTransformationWrapper):
    """Wrapper to select a channel from an array-like dataset object.

    Args:
        volume: The array-like input dataset.
        channel: The channel that will be selected.
    """
    def __init__(self, volume: ArrayLike, channel: int):
        self.channel = channel
        super().__init__(volume, lambda x: x[self.channel], with_channels=True)

    @property
    def shape(self):
        return self._volume.shape[1:]

    @property
    def chunks(self):
        return self._volume.chunks[1:]

    @property
    def ndim(self):
        return self._volume.ndim - 1


def _run_segmentation(pred, output, seeds, chunks, seed_threshold, min_size, verbose, original_shape):
    # Create wrappers for selecting the foreground and the boundary channel.
    foreground = SelectChannel(pred, 0)
    boundaries = SelectChannel(pred, 1)

    # Create wrappers for subtracting the boundaries from the foreground and thresholding the result,
    # then compute the seeds via connected components on this seed input.
    seed_input = ThresholdWrapper(
        MultiTransformationWrapper(np.subtract, foreground, boundaries), seed_threshold
    )
    parallel.label(seed_input, seeds, verbose=verbose, block_shape=chunks)

    # Run a watershed to extend the seeds back to the boundaries, restricted to the foreground mask.
    mask = ThresholdWrapper(foreground, 0.5)

    # Resize if necessary.
    if original_shape is not None:
        boundaries = ResizedVolume(boundaries, original_shape, order=1)
        seeds = ResizedVolume(seeds, original_shape, order=0)
        mask = ResizedVolume(mask, original_shape, order=0)

    parallel.seeded_watershed(
        boundaries, seeds=seeds, out=output, verbose=verbose, mask=mask, block_shape=chunks, halo=3 * (16,)
    )

    # Run the size filter.
    if min_size > 0:
        parallel.size_filter(output, output, min_size=min_size, verbose=verbose, block_shape=chunks)


def scalable_segmentation(
    input_: ArrayLike,
    output: ArrayLike,
    model: torch.nn.Module,
    tiling: Optional[Dict[str, Dict[str, int]]] = None,
    scale: Optional[List[float]] = None,
    seed_threshold: float = 0.5,
    min_size: int = 500,
    prediction: Optional[ArrayLike] = None,
    verbose: bool = True,
    mask: Optional[ArrayLike] = None,
) -> None:
    """Run segmentation based on a prediction with a foreground and a boundary channel.

    This function first subtracts the boundary prediction from the foreground prediction,
    then applies a threshold, connected components, and a watershed to fit the components
    back to the foreground. All processing steps are implemented in a scalable fashion,
    so that the function runs for large input volumes.

    Args:
        input_: The input data.
        output: The array for storing the output segmentation.
            Can be a numpy array, a zarr array, or similar.
        model: The model for prediction.
        tiling: The tiling configuration for the prediction.
        scale: The scale factor to use for rescaling the input volume before prediction.
        seed_threshold: The threshold applied before computing connected components.
        min_size: The minimum size of a segmented object to be kept.
        prediction: The array for storing the prediction.
            If given, this can be a numpy array, a zarr array, or similar.
            If not given, the prediction will be stored in a temporary n5 array.
        verbose: Whether to print timing information.
        mask: An optional mask to restrict the segmentation. Currently not implemented.
    """
    if mask is not None:
        raise NotImplementedError
    assert model.out_channels == 2

    # Create a temporary directory for storing the predictions.
    chunks = (128,) * 3
    with tempfile.TemporaryDirectory() as tmp_dir:

        if scale is None or np.allclose(scale, 1.0, atol=1e-3):
            original_shape = None
        else:
            original_shape = input_.shape
            new_shape = tuple(int(sh * sc) for sh, sc in zip(input_.shape, scale))
            input_ = ResizedVolume(input_, shape=new_shape, order=1)

        if prediction is None:
            # Create the dataset for storing the prediction.
            tmp_pred = os.path.join(tmp_dir, "prediction.n5")
            f = open_file(tmp_pred, mode="a")
            pred_shape = (2,) + input_.shape
            pred_chunks = (1,) + chunks
            prediction = f.create_dataset("pred", shape=pred_shape, dtype="float32", chunks=pred_chunks)
        else:
            assert prediction.shape[0] == 2
            assert prediction.shape[1:] == input_.shape

        # Create temporary storage for the seeds.
        tmp_seeds = os.path.join(tmp_dir, "seeds.n5")
        f = open_file(tmp_seeds, mode="a")
        seeds = f.create_dataset("seeds", shape=input_.shape, dtype="uint64", chunks=chunks)

        # Run prediction and segmentation.
        get_prediction(input_, prediction=prediction, tiling=tiling, model=model, verbose=verbose)
        _run_segmentation(prediction, output, seeds, chunks, seed_threshold, min_size, verbose, original_shape)
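A possible way to call the new scalable segmentation on a large volume, as a sketch under stated assumptions: the module path synapse_net.inference.scalable_segmentation is a guess (the new file's path is not shown in this view), and the zarr container with a "raw" dataset is hypothetical:

import zarr

from synapse_net.inference.inference import get_model
from synapse_net.inference.scalable_segmentation import scalable_segmentation  # assumed module path

f = zarr.open("tomogram.zarr", mode="a")  # hypothetical container with a "raw" dataset
input_ = f["raw"]
output = f.create_dataset(
    "compartments", shape=input_.shape, dtype="uint64", chunks=(128, 128, 128), overwrite=True
)

model = get_model("compartments")
# Downscale by 2x before prediction; the segmentation is resized back to the input shape,
# and the prediction itself is kept in a temporary n5 file.
scalable_segmentation(input_, output, model, scale=[0.5, 0.5, 0.5], min_size=500)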
