Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions flamingo_tools/segmentation/unet_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def prediction_impl(
scale,
block_shape,
halo,
output_channels=3,
apply_postprocessing=True,
prediction_instances=1,
slurm_task_id=0,
mean=None,
Expand All @@ -75,7 +77,10 @@ def prediction_impl(
model = torch.load(model_path, weights_only=False)

mask_path = os.path.join(output_folder, "mask.zarr")
image_mask = z5py.File(mask_path, "r")["mask"]
if os.path.exists(mask_path):
image_mask = z5py.File(mask_path, "r")["mask"]
else:
image_mask = None

input_ = read_image_data(input_path, input_key)
chunks = getattr(input_, "chunks", (64, 64, 64))
Expand Down Expand Up @@ -122,10 +127,20 @@ def preprocess(raw):
raw /= std
return raw

# Smooth the distance prediction channel.
def postprocess(x):
x[1] = vigra.filters.gaussianSmoothing(x[1], sigma=2.0)
return x
if apply_postprocessing:
# Smooth the distance prediction channel.
def postprocess(x):
x[1] = vigra.filters.gaussianSmoothing(x[1], sigma=2.0)
return x
else:
postprocess = None if output_channels > 1 else lambda x: x.squeeze()

if output_channels > 1:
output_shape = (output_channels,) + input_.shape
output_chunks = (1,) + block_shape
else:
output_shape = input_.shape
output_chunks = block_shape

shape = input_.shape
ndim = len(shape)
Expand All @@ -142,8 +157,8 @@ def postprocess(x):
with open_file(output_path, "a") as f:
output = f.require_dataset(
"prediction",
shape=(3,) + input_.shape,
chunks=(1,) + block_shape,
shape=output_shape,
chunks=output_chunks,
compression="gzip",
dtype="float32",
)
Expand Down
1 change: 1 addition & 0 deletions scripts/synapse_marker_detection/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
data/
114 changes: 114 additions & 0 deletions scripts/synapse_marker_detection/check_synapse_prediction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import os
from glob import glob
from pathlib import Path

import h5py
import imageio.v3 as imageio
import napari
import numpy as np
import pandas as pd
import zarr

# from skimage.feature import blob_dog
from skimage.feature import peak_local_max
from torch_em.util import load_model
from torch_em.util.prediction import predict_with_halo
from train_synapse_detection import get_paths
from tqdm import tqdm

# INPUT_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/test_crops"
INPUT_ROOT = "./data/test_crops"
OUTPUT_ROOT = "./predictions"
DETECTION_OUT_ROOT = "./detections"


def run_prediction(val_image):
model = load_model("./checkpoints/synapse_detection_v1")
block_shape = (32, 384, 384)
halo = (8, 64, 64)
pred = predict_with_halo(val_image, model, [0], block_shape, halo)
return pred.squeeze()


def require_prediction(image_data, output_path):
key = "prediction"
if os.path.exists(output_path):
with h5py.File(output_path, "r") as f:
pred = f[key][:]
else:
pred = run_prediction(image_data)
with h5py.File(output_path, "w") as f:
f.create_dataset(key, data=pred, compression="gzip")
return pred


def run_postprocessing(pred):
# print("Running local max ...")
# coords = blob_dog(pred)
coords = peak_local_max(pred, min_distance=2, threshold_abs=0.5)
# print("... done")
return coords


def visualize_results(image_data, pred, coords=None, val_coords=None, title=None):
v = napari.Viewer()
v.add_image(image_data)
pred = pred.clip(0, pred.max())
v.add_image(pred)
if coords is not None:
v.add_points(coords, name="predicted_synapses", face_color="yellow")
if val_coords is not None:
v.add_points(val_coords, face_color="green", name="synapse_annotations")
if title is not None:
v.title = title
napari.run()


def check_val_image():
val_paths, _ = get_paths("val")
val_path = val_paths[0]
val_image = zarr.open(val_path)["raw"][:]

os.makedirs(os.path.join(OUTPUT_ROOT, "val"), exist_ok=True)
output_path = os.path.join(OUTPUT_ROOT, "val", os.path.basename(val_path).replace(".zarr", ".h5"))
pred = require_prediction(val_image, output_path)

visualize_results(val_image, pred)


def check_new_images(view=False, save_detection=False):
inputs = glob(os.path.join(INPUT_ROOT, "*.tif"))
output_folder = os.path.join(OUTPUT_ROOT, "new_crops")
os.makedirs(output_folder, exist_ok=True)
for path in tqdm(inputs):
print(path)
name = os.path.basename(path)
if name == "M_AMD_58L_avgblendfused_RibB.tif":
continue
image_data = imageio.imread(path)
output_path = os.path.join(output_folder, name.replace(".tif", ".h5"))
# if not os.path.exists(output_path):
# continue
pred = require_prediction(image_data, output_path)
if view or save_detection:
coords = run_postprocessing(pred)
if view:
print("Number of synapses:", len(coords))
visualize_results(image_data, pred, coords=coords, title=name)
if save_detection:
os.makedirs(DETECTION_OUT_ROOT, exist_ok=True)
coords = np.concatenate([np.arange(0, len(coords))[:, None], coords], axis=1)
coords = pd.DataFrame(coords, columns=["index", "axis-0", "axis-1", "axis-2"])
fname = Path(path).stem
detection_save_path = os.path.join(DETECTION_OUT_ROOT, f"{fname}.csv")
coords.to_csv(detection_save_path, index=False)


# TODO update to support post-processing and showing annotations for the val data
def main():
# check_val_image()
check_new_images(view=False, save_detection=True)


if __name__ == "__main__":
main()
196 changes: 196 additions & 0 deletions scripts/synapse_marker_detection/detection_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
import numpy as np
import pandas as pd
import torch
import zarr

from skimage.filters import gaussian
from torch_em.util import ensure_tensor_with_channels


# Process labels stored in json napari style.
# I don't actually think that we need the epsilon here, but will leave it for now.
def process_labels(label_path, shape, sigma, eps, bb=None):
points = pd.read_csv(label_path)

if bb:
(z_min, z_max), (y_min, y_max), (x_min, x_max) = [(s.start, s.stop) for s in bb]
restricted_shape = (z_max - z_min, y_max - y_min, x_max - x_min)
labels = np.zeros(restricted_shape, dtype="float32")
shape = restricted_shape
else:
labels = np.zeros(shape, dtype="float32")

assert len(points.columns) == len(shape)
z_coords, y_coords, x_coords = points["axis-0"], points["axis-1"], points["axis-2"]
if bb is not None:
z_coords -= z_min
y_coords -= y_min
x_coords -= x_min
mask = np.logical_and.reduce([
np.logical_and(z_coords >= 0, z_coords < (z_max - z_min)),
np.logical_and(y_coords >= 0, y_coords < (y_max - y_min)),
np.logical_and(x_coords >= 0, x_coords < (x_max - x_min)),
])
z_coords, y_coords, x_coords = z_coords[mask], y_coords[mask], x_coords[mask]

coords = tuple(
np.clip(np.round(coord).astype("int"), 0, coord_max - 1) for coord, coord_max in zip(
(z_coords, y_coords, x_coords), shape
)
)

labels[coords] = 1
labels = gaussian(labels, sigma)
# TODO better normalization?
labels /= (labels.max() + 1e-7)
labels *= 4
return labels


class DetectionDataset(torch.utils.data.Dataset):
max_sampling_attempts = 500

@staticmethod
def compute_len(shape, patch_shape):
if patch_shape is None:
return 1
else:
n_samples = int(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)]))
return n_samples

def __init__(
self,
raw_path,
label_path,
patch_shape,
raw_key,
raw_transform=None,
label_transform=None,
transform=None,
dtype=torch.float32,
label_dtype=torch.float32,
n_samples=None,
sampler=None,
eps=1e-8,
sigma=None,
**kwargs,
):
self.raw_path = raw_path
self.label_path = label_path
self.raw_key = raw_key
self._ndim = 3

assert len(patch_shape) == self._ndim
self.patch_shape = patch_shape

self.raw_transform = raw_transform
self.label_transform = label_transform
self.transform = transform
self.sampler = sampler

self.dtype = dtype
self.label_dtype = label_dtype

self.eps = eps
self.sigma = sigma

with zarr.open(self.raw_path, "r") as f:
self.shape = f[self.raw_key].shape

if n_samples is None:
self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples
else:
self._len = n_samples

def __len__(self):
return self._len

@property
def ndim(self):
return self._ndim

def _sample_bounding_box(self, shape):
if any(sh < psh for sh, psh in zip(shape, self.patch_shape)):
raise NotImplementedError(
f"Image padding is not supported yet. Data shape {shape}, patch shape {self.patch_shape}"
)
bb_start = [
np.random.randint(0, sh - psh) if sh - psh > 0 else 0
for sh, psh in zip(shape, self.patch_shape)
]
return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.patch_shape))

def _get_sample(self, index):
raw, label_path = self.raw_path, self.label_path

raw = zarr.open(raw)[self.raw_key]
shape = raw.shape

bb = self._sample_bounding_box(shape)
label = process_labels(label_path, shape, self.sigma, self.eps, bb=bb)

have_raw_channels = raw.ndim == 4 # 3D with channels
have_label_channels = label.ndim == 4
if have_label_channels:
raise NotImplementedError("Multi-channel labels are not supported.")

prefix_box = tuple()
if have_raw_channels:
if shape[-1] < 16:
shape = shape[:-1]
else:
shape = shape[1:]
prefix_box = (slice(None), )

raw_patch = np.array(raw[prefix_box + bb])
label_patch = np.array(label)

if self.sampler is not None:
assert False, "Sampler not implemented"
# sample_id = 0
# while not self.sampler(raw_patch, label_patch):
# bb = self._sample_bounding_box(shape)
# raw_patch = np.array(raw[prefix_box + bb])
# label_patch = np.array(label[bb])
# sample_id += 1
# if sample_id > self.max_sampling_attempts:
# raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")

if have_raw_channels and len(prefix_box) == 0:
raw_patch = raw_patch.transpose((3, 0, 1, 2)) # Channels, Depth, Height, Width

return raw_patch, label_patch

def __getitem__(self, index):
raw, labels = self._get_sample(index)
# initial_label_dtype = labels.dtype

if self.raw_transform is not None:
raw = self.raw_transform(raw)

if self.label_transform is not None:
labels = self.label_transform(labels)

if self.transform is not None:
raw, labels = self.transform(raw, labels)

raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
return raw, labels


if __name__ == "__main__":
import napari

raw_path = "training_data/images/10.1L_mid_IHCribboncount_5_Z.zarr"
label_path = "training_data/labels/10.1L_mid_IHCribboncount_5_Z.csv"

f = zarr.open(raw_path, "r")
raw = f["raw"][:]

labels = process_labels(label_path, shape=raw.shape, sigma=1, eps=1e-7)

v = napari.Viewer()
v.add_image(raw)
v.add_image(labels)
napari.run()
Loading