Skip to content

Commit 4af55c2

Browse files
Support for training and detection of synaptic spots (#24)
Implement training and prediction for synaptic spots
1 parent cf3ecda commit 4af55c2

File tree

7 files changed

+553
-7
lines changed

7 files changed

+553
-7
lines changed

flamingo_tools/segmentation/unet_prediction.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ def prediction_impl(
6060
scale,
6161
block_shape,
6262
halo,
63+
output_channels=3,
64+
apply_postprocessing=True,
6365
prediction_instances=1,
6466
slurm_task_id=0,
6567
mean=None,
@@ -75,7 +77,10 @@ def prediction_impl(
7577
model = torch.load(model_path, weights_only=False)
7678

7779
mask_path = os.path.join(output_folder, "mask.zarr")
78-
image_mask = z5py.File(mask_path, "r")["mask"]
80+
if os.path.exists(mask_path):
81+
image_mask = z5py.File(mask_path, "r")["mask"]
82+
else:
83+
image_mask = None
7984

8085
input_ = read_image_data(input_path, input_key)
8186
chunks = getattr(input_, "chunks", (64, 64, 64))
@@ -122,10 +127,20 @@ def preprocess(raw):
122127
raw /= std
123128
return raw
124129

125-
# Smooth the distance prediction channel.
126-
def postprocess(x):
127-
x[1] = vigra.filters.gaussianSmoothing(x[1], sigma=2.0)
128-
return x
130+
if apply_postprocessing:
131+
# Smooth the distance prediction channel.
132+
def postprocess(x):
133+
x[1] = vigra.filters.gaussianSmoothing(x[1], sigma=2.0)
134+
return x
135+
else:
136+
postprocess = None if output_channels > 1 else lambda x: x.squeeze()
137+
138+
if output_channels > 1:
139+
output_shape = (output_channels,) + input_.shape
140+
output_chunks = (1,) + block_shape
141+
else:
142+
output_shape = input_.shape
143+
output_chunks = block_shape
129144

130145
shape = input_.shape
131146
ndim = len(shape)
@@ -142,8 +157,8 @@ def postprocess(x):
142157
with open_file(output_path, "a") as f:
143158
output = f.require_dataset(
144159
"prediction",
145-
shape=(3,) + input_.shape,
146-
chunks=(1,) + block_shape,
160+
shape=output_shape,
161+
chunks=output_chunks,
147162
compression="gzip",
148163
dtype="float32",
149164
)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
data/
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import os
2+
from glob import glob
3+
from pathlib import Path
4+
5+
import h5py
6+
import imageio.v3 as imageio
7+
import napari
8+
import numpy as np
9+
import pandas as pd
10+
import zarr
11+
12+
# from skimage.feature import blob_dog
13+
from skimage.feature import peak_local_max
14+
from torch_em.util import load_model
15+
from torch_em.util.prediction import predict_with_halo
16+
from train_synapse_detection import get_paths
17+
from tqdm import tqdm
18+
19+
# INPUT_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/test_crops"
20+
INPUT_ROOT = "./data/test_crops"
21+
OUTPUT_ROOT = "./predictions"
22+
DETECTION_OUT_ROOT = "./detections"
23+
24+
25+
def run_prediction(val_image):
26+
model = load_model("./checkpoints/synapse_detection_v1")
27+
block_shape = (32, 384, 384)
28+
halo = (8, 64, 64)
29+
pred = predict_with_halo(val_image, model, [0], block_shape, halo)
30+
return pred.squeeze()
31+
32+
33+
def require_prediction(image_data, output_path):
34+
key = "prediction"
35+
if os.path.exists(output_path):
36+
with h5py.File(output_path, "r") as f:
37+
pred = f[key][:]
38+
else:
39+
pred = run_prediction(image_data)
40+
with h5py.File(output_path, "w") as f:
41+
f.create_dataset(key, data=pred, compression="gzip")
42+
return pred
43+
44+
45+
def run_postprocessing(pred):
46+
# print("Running local max ...")
47+
# coords = blob_dog(pred)
48+
coords = peak_local_max(pred, min_distance=2, threshold_abs=0.5)
49+
# print("... done")
50+
return coords
51+
52+
53+
def visualize_results(image_data, pred, coords=None, val_coords=None, title=None):
54+
v = napari.Viewer()
55+
v.add_image(image_data)
56+
pred = pred.clip(0, pred.max())
57+
v.add_image(pred)
58+
if coords is not None:
59+
v.add_points(coords, name="predicted_synapses", face_color="yellow")
60+
if val_coords is not None:
61+
v.add_points(val_coords, face_color="green", name="synapse_annotations")
62+
if title is not None:
63+
v.title = title
64+
napari.run()
65+
66+
67+
def check_val_image():
68+
val_paths, _ = get_paths("val")
69+
val_path = val_paths[0]
70+
val_image = zarr.open(val_path)["raw"][:]
71+
72+
os.makedirs(os.path.join(OUTPUT_ROOT, "val"), exist_ok=True)
73+
output_path = os.path.join(OUTPUT_ROOT, "val", os.path.basename(val_path).replace(".zarr", ".h5"))
74+
pred = require_prediction(val_image, output_path)
75+
76+
visualize_results(val_image, pred)
77+
78+
79+
def check_new_images(view=False, save_detection=False):
80+
inputs = glob(os.path.join(INPUT_ROOT, "*.tif"))
81+
output_folder = os.path.join(OUTPUT_ROOT, "new_crops")
82+
os.makedirs(output_folder, exist_ok=True)
83+
for path in tqdm(inputs):
84+
print(path)
85+
name = os.path.basename(path)
86+
if name == "M_AMD_58L_avgblendfused_RibB.tif":
87+
continue
88+
image_data = imageio.imread(path)
89+
output_path = os.path.join(output_folder, name.replace(".tif", ".h5"))
90+
# if not os.path.exists(output_path):
91+
# continue
92+
pred = require_prediction(image_data, output_path)
93+
if view or save_detection:
94+
coords = run_postprocessing(pred)
95+
if view:
96+
print("Number of synapses:", len(coords))
97+
visualize_results(image_data, pred, coords=coords, title=name)
98+
if save_detection:
99+
os.makedirs(DETECTION_OUT_ROOT, exist_ok=True)
100+
coords = np.concatenate([np.arange(0, len(coords))[:, None], coords], axis=1)
101+
coords = pd.DataFrame(coords, columns=["index", "axis-0", "axis-1", "axis-2"])
102+
fname = Path(path).stem
103+
detection_save_path = os.path.join(DETECTION_OUT_ROOT, f"{fname}.csv")
104+
coords.to_csv(detection_save_path, index=False)
105+
106+
107+
# TODO update to support post-processing and showing annotations for the val data
108+
def main():
109+
# check_val_image()
110+
check_new_images(view=False, save_detection=True)
111+
112+
113+
if __name__ == "__main__":
114+
main()
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import numpy as np
2+
import pandas as pd
3+
import torch
4+
import zarr
5+
6+
from skimage.filters import gaussian
7+
from torch_em.util import ensure_tensor_with_channels
8+
9+
10+
# Process labels stored in json napari style.
11+
# I don't actually think that we need the epsilon here, but will leave it for now.
12+
def process_labels(label_path, shape, sigma, eps, bb=None):
13+
points = pd.read_csv(label_path)
14+
15+
if bb:
16+
(z_min, z_max), (y_min, y_max), (x_min, x_max) = [(s.start, s.stop) for s in bb]
17+
restricted_shape = (z_max - z_min, y_max - y_min, x_max - x_min)
18+
labels = np.zeros(restricted_shape, dtype="float32")
19+
shape = restricted_shape
20+
else:
21+
labels = np.zeros(shape, dtype="float32")
22+
23+
assert len(points.columns) == len(shape)
24+
z_coords, y_coords, x_coords = points["axis-0"], points["axis-1"], points["axis-2"]
25+
if bb is not None:
26+
z_coords -= z_min
27+
y_coords -= y_min
28+
x_coords -= x_min
29+
mask = np.logical_and.reduce([
30+
np.logical_and(z_coords >= 0, z_coords < (z_max - z_min)),
31+
np.logical_and(y_coords >= 0, y_coords < (y_max - y_min)),
32+
np.logical_and(x_coords >= 0, x_coords < (x_max - x_min)),
33+
])
34+
z_coords, y_coords, x_coords = z_coords[mask], y_coords[mask], x_coords[mask]
35+
36+
coords = tuple(
37+
np.clip(np.round(coord).astype("int"), 0, coord_max - 1) for coord, coord_max in zip(
38+
(z_coords, y_coords, x_coords), shape
39+
)
40+
)
41+
42+
labels[coords] = 1
43+
labels = gaussian(labels, sigma)
44+
# TODO better normalization?
45+
labels /= (labels.max() + 1e-7)
46+
labels *= 4
47+
return labels
48+
49+
50+
class DetectionDataset(torch.utils.data.Dataset):
51+
max_sampling_attempts = 500
52+
53+
@staticmethod
54+
def compute_len(shape, patch_shape):
55+
if patch_shape is None:
56+
return 1
57+
else:
58+
n_samples = int(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
59+
return n_samples
60+
61+
def __init__(
62+
self,
63+
raw_path,
64+
label_path,
65+
patch_shape,
66+
raw_key,
67+
raw_transform=None,
68+
label_transform=None,
69+
transform=None,
70+
dtype=torch.float32,
71+
label_dtype=torch.float32,
72+
n_samples=None,
73+
sampler=None,
74+
eps=1e-8,
75+
sigma=None,
76+
**kwargs,
77+
):
78+
self.raw_path = raw_path
79+
self.label_path = label_path
80+
self.raw_key = raw_key
81+
self._ndim = 3
82+
83+
assert len(patch_shape) == self._ndim
84+
self.patch_shape = patch_shape
85+
86+
self.raw_transform = raw_transform
87+
self.label_transform = label_transform
88+
self.transform = transform
89+
self.sampler = sampler
90+
91+
self.dtype = dtype
92+
self.label_dtype = label_dtype
93+
94+
self.eps = eps
95+
self.sigma = sigma
96+
97+
with zarr.open(self.raw_path, "r") as f:
98+
self.shape = f[self.raw_key].shape
99+
100+
if n_samples is None:
101+
self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples
102+
else:
103+
self._len = n_samples
104+
105+
def __len__(self):
106+
return self._len
107+
108+
@property
109+
def ndim(self):
110+
return self._ndim
111+
112+
def _sample_bounding_box(self, shape):
113+
if any(sh < psh for sh, psh in zip(shape, self.patch_shape)):
114+
raise NotImplementedError(
115+
f"Image padding is not supported yet. Data shape {shape}, patch shape {self.patch_shape}"
116+
)
117+
bb_start = [
118+
np.random.randint(0, sh - psh) if sh - psh > 0 else 0
119+
for sh, psh in zip(shape, self.patch_shape)
120+
]
121+
return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.patch_shape))
122+
123+
def _get_sample(self, index):
124+
raw, label_path = self.raw_path, self.label_path
125+
126+
raw = zarr.open(raw)[self.raw_key]
127+
shape = raw.shape
128+
129+
bb = self._sample_bounding_box(shape)
130+
label = process_labels(label_path, shape, self.sigma, self.eps, bb=bb)
131+
132+
have_raw_channels = raw.ndim == 4 # 3D with channels
133+
have_label_channels = label.ndim == 4
134+
if have_label_channels:
135+
raise NotImplementedError("Multi-channel labels are not supported.")
136+
137+
prefix_box = tuple()
138+
if have_raw_channels:
139+
if shape[-1] < 16:
140+
shape = shape[:-1]
141+
else:
142+
shape = shape[1:]
143+
prefix_box = (slice(None), )
144+
145+
raw_patch = np.array(raw[prefix_box + bb])
146+
label_patch = np.array(label)
147+
148+
if self.sampler is not None:
149+
assert False, "Sampler not implemented"
150+
# sample_id = 0
151+
# while not self.sampler(raw_patch, label_patch):
152+
# bb = self._sample_bounding_box(shape)
153+
# raw_patch = np.array(raw[prefix_box + bb])
154+
# label_patch = np.array(label[bb])
155+
# sample_id += 1
156+
# if sample_id > self.max_sampling_attempts:
157+
# raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
158+
159+
if have_raw_channels and len(prefix_box) == 0:
160+
raw_patch = raw_patch.transpose((3, 0, 1, 2)) # Channels, Depth, Height, Width
161+
162+
return raw_patch, label_patch
163+
164+
def __getitem__(self, index):
165+
raw, labels = self._get_sample(index)
166+
# initial_label_dtype = labels.dtype
167+
168+
if self.raw_transform is not None:
169+
raw = self.raw_transform(raw)
170+
171+
if self.label_transform is not None:
172+
labels = self.label_transform(labels)
173+
174+
if self.transform is not None:
175+
raw, labels = self.transform(raw, labels)
176+
177+
raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
178+
labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
179+
return raw, labels
180+
181+
182+
if __name__ == "__main__":
183+
import napari
184+
185+
raw_path = "training_data/images/10.1L_mid_IHCribboncount_5_Z.zarr"
186+
label_path = "training_data/labels/10.1L_mid_IHCribboncount_5_Z.csv"
187+
188+
f = zarr.open(raw_path, "r")
189+
raw = f["raw"][:]
190+
191+
labels = process_labels(label_path, shape=raw.shape, sigma=1, eps=1e-7)
192+
193+
v = napari.Viewer()
194+
v.add_image(raw)
195+
v.add_image(labels)
196+
napari.run()

0 commit comments

Comments
 (0)