Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 136 additions & 1 deletion flamingo_tools/segmentation/postprocessing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import math
import multiprocessing as mp
import threading
from concurrent import futures
from typing import Callable, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

import elf.parallel as parallel
import numpy as np
Expand All @@ -15,6 +16,9 @@
from scipy.spatial import distance
from scipy.spatial import cKDTree, ConvexHull
from skimage import measure
from skimage.filters import gaussian
from skimage.feature import peak_local_max
from skimage.segmentation import find_boundaries, watershed
from sklearn.neighbors import NearestNeighbors
from tqdm import tqdm

Expand Down Expand Up @@ -732,3 +736,134 @@ def filter_cochlea_volume(
combined_dilated[combined_dilated > 0] = 1

return combined_dilated


def split_nonconvex_objects(
    segmentation: np.typing.ArrayLike,
    output: np.typing.ArrayLike,
    segmentation_table: pd.DataFrame,
    min_size: int,
    resolution: Union[float, Sequence[float]],
    height_map: Optional[np.typing.ArrayLike] = None,
    component_labels: Optional[List[int]] = None,
    n_threads: Optional[int] = None,
) -> Dict[int, List[int]]:
    """Split non-convex objects into multiple parts inplace.

    Candidate split points are found as local maxima of the smoothed boundary distance
    transform within each object; the object is then re-partitioned with a seeded
    watershed. Objects that are below the minimal size, have a single maximum, or whose
    split would only yield fragments below the minimal size are left unchanged.

    Args:
        segmentation: The input segmentation.
        output: The output data the split segmentation is written to.
        segmentation_table: The object measurement table. It must contain the columns
            label_id, n_pixels, component_labels, and the bounding box columns
            bb_min_[zyx] / bb_max_[zyx] in physical coordinates.
        min_size: The minimal object size in pixels. Smaller objects are not split and
            split fragments below this size are merged back into the kept fragments.
        resolution: The voxel size of the data; either a scalar or one value per axis (ZYX).
        height_map: Optional height map for the watershed. By default the inverted
            distance transform is used.
        component_labels: Optional component labels to restrict which objects are processed.
            By default all objects in the table are processed.
        n_threads: The number of threads used for parallelization.
            By default all available cores are used.

    Returns:
        Dictionary that maps each processed object id to the list of ids it was split into.
    """
    # Accept int as well as float for a scalar resolution.
    if isinstance(resolution, (int, float)):
        resolution = [float(resolution)] * 3
    assert len(resolution) == 3
    resolution = np.array(resolution)

    lock = threading.Lock()
    # New object ids are allocated after the current number of objects.
    offset = len(segmentation_table)

    def split_object(object_id):
        nonlocal offset

        row = segmentation_table[segmentation_table.label_id == object_id]
        if row.n_pixels.values[0] < min_size:
            # Too small to be split.
            return [object_id]

        # The bounding box is stored in physical coordinates; convert it to voxel
        # coordinates, pad it by one voxel, and clip it to the segmentation shape.
        bb_min = np.array([
            row.bb_min_z.values[0], row.bb_min_y.values[0], row.bb_min_x.values[0],
        ]) / resolution
        bb_max = np.array([
            row.bb_max_z.values[0], row.bb_max_y.values[0], row.bb_max_x.values[0],
        ]) / resolution

        bb_min = np.maximum(bb_min.astype(int) - 1, np.array([0, 0, 0]))
        bb_max = np.minimum(bb_max.astype(int) + 1, np.array(list(segmentation.shape)))
        bb = tuple(slice(mi, ma) for mi, ma in zip(bb_min, bb_max))

        # Skip objects with an implausibly large extent. This is due to segmentation artifacts.
        bb_shape = bb_max - bb_min
        if (bb_shape > 500).any():
            print(object_id, "has a too large shape:", bb_shape)
            return [object_id]

        # Compute the boundary distance transform, restricted to this object.
        seg = segmentation[bb]
        mask = ~find_boundaries(seg)
        dist = distance_transform_edt(mask, sampling=resolution)

        seg_mask = seg == object_id
        dist[~seg_mask] = 0
        # Smooth the distances (anisotropic kernel) before computing the seed points.
        dist = gaussian(dist, (0.6, 1.2, 1.2))
        maxima = peak_local_max(dist, min_distance=3, exclude_border=True)

        if len(maxima) == 1:
            # A single maximum means there is nothing to split.
            return [object_id]

        # Reserve a contiguous id range for the potential new fragments.
        with lock:
            old_offset = offset
            offset += len(maxima)

        seeds = np.zeros(seg.shape, dtype=int)
        for i, pos in enumerate(maxima, 1):
            seeds[tuple(pos)] = old_offset + i

        # Use the inverted distance transform as height map, unless one was passed.
        if height_map is None:
            hmap = dist.max() - dist
        else:
            hmap = height_map[bb]
        new_seg = watershed(hmap, markers=seeds, mask=seg_mask)

        seg_ids, sizes = np.unique(new_seg, return_counts=True)
        seg_ids, sizes = seg_ids[1:], sizes[1:]  # Exclude the background.

        keep_ids = seg_ids[sizes > min_size]
        if len(keep_ids) < 2:
            # The split did not produce at least two sufficiently large fragments.
            return [object_id]

        elif len(keep_ids) != len(seg_ids):
            # Merge fragments below the size threshold into the kept fragments by
            # re-running the watershed with only the kept fragments as seeds.
            new_seg[~np.isin(new_seg, keep_ids)] = 0
            new_seg = watershed(hmap, markers=new_seg, mask=seg_mask)

        with lock:
            out = output[bb]
            out[seg_mask] = new_seg[seg_mask]
            output[bb] = out

        return keep_ids.tolist()

    if component_labels is None:
        object_ids = segmentation_table.label_id.values
    else:
        object_ids = segmentation_table[segmentation_table.component_labels.isin(component_labels)].label_id.values

    if n_threads is None:
        n_threads = mp.cpu_count()

    with futures.ThreadPoolExecutor(n_threads) as tp:
        new_id_mapping = list(
            tqdm(tp.map(split_object, object_ids), total=len(object_ids), desc="Split non-convex objects")
        )

    new_id_mapping = {object_id: mapped_ids for object_id, mapped_ids in zip(object_ids, new_id_mapping)}
    return new_id_mapping
9 changes: 7 additions & 2 deletions flamingo_tools/training/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def get_supervised_loader(
image_key: Optional[str] = None,
label_key: Optional[str] = None,
n_samples: Optional[int] = None,
raw_transform: Optional[callable] = None,
anisotropy: Optional[float] = None,
) -> DataLoader:
"""Get a data loader for a supervised segmentation task.

Expand All @@ -39,19 +41,22 @@ def get_supervised_loader(
image_key: Internal path for the image data. This is only required for hdf5/zarr/n5 data.
image_key: Internal path for the label masks. This is only required for hdf5/zarr/n5 data.
n_samples: The number of samples to use for training.
raw_transform: Optional transformation for the raw data.
anisotropy: The anisotropy factor for distance target computation.

Returns:
The data loader.
"""
assert len(image_paths) == len(label_paths)
assert len(image_paths) > 0
sampling = None if anisotropy is None else (anisotropy, 1.0, 1.0)
label_transform = torch_em.transform.label.PerObjectDistanceTransform(
distances=True, boundary_distances=True, foreground=True,
distances=True, boundary_distances=True, foreground=True, sampling=sampling,
)
sampler = torch_em.data.sampler.MinInstanceSampler(p_reject=0.8)
loader = torch_em.default_segmentation_loader(
raw_paths=image_paths, raw_key=image_key, label_paths=label_paths, label_key=label_key,
batch_size=batch_size, patch_shape=patch_shape, label_transform=label_transform,
n_samples=n_samples, num_workers=4, shuffle=True, sampler=sampler
n_samples=n_samples, num_workers=4, shuffle=True, sampler=sampler, raw_transform=raw_transform,
)
return loader
15 changes: 13 additions & 2 deletions scripts/figures/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def _get_mapping(animal):
return bin_edges, bin_labels


def frequency_mapping(frequencies, values, animal="mouse", transduction_efficiency=False):
def frequency_mapping(frequencies, values, animal="mouse", transduction_efficiency=False, categorical=False):
# Get the mapping of frequencies to octave bands for the given species.
bin_edges, bin_labels = _get_mapping(animal)

Expand All @@ -34,7 +34,18 @@ def frequency_mapping(frequencies, values, animal="mouse", transduction_efficien
df["freq_khz"], bins=bin_edges, labels=bin_labels, right=False
)

if transduction_efficiency: # We compute the transduction efficiency per band.
if categorical:
assert not transduction_efficiency
categories = pd.unique(df.value)
num_tot = df.groupby("octave_band", observed=False).size()
value_by_band = {}
for cat in categories:
pos_cat = df[df.value == cat].groupby("octave_band", observed=False).size()
cat_by_band = (pos_cat / num_tot).reindex(bin_labels)
cat_by_band = cat_by_band.reset_index()
cat_by_band.columns = ["octave_band", "value"]
value_by_band[cat] = cat_by_band
elif transduction_efficiency: # We compute the transduction efficiency per band.
num_pos = df[df["value"] == 1].groupby("octave_band", observed=False).size()
num_tot = df[df["value"].isin([1, 2])].groupby("octave_band", observed=False).size()
value_by_band = (num_pos / num_tot).reindex(bin_labels)
Expand Down
50 changes: 50 additions & 0 deletions scripts/intensity_masking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import argparse

import imageio.v3 as imageio
import numpy as np
from scipy.ndimage import binary_dilation, binary_closing, distance_transform_edt


def intensity_masking(image_path, seg_path, out_path, modulation_strength=10, dilation=2, view=False):
    """Modulate the image intensity so that regions close to the segmentation are emphasized.

    Args:
        image_path: Path to the image data.
        seg_path: Path to the segmentation; non-zero values are treated as foreground.
        out_path: Path for writing the modulated image.
        modulation_strength: Multiplicative factor applied to the modulation mask.
        dilation: Number of dilation iterations for growing the segmentation mask.
        view: Whether to show the result in napari instead of writing it to out_path.
    """
    seg = imageio.imread(seg_path)
    # Grow the segmentation mask and close small holes.
    # BUGFIX: 'dilation' was previously ignored (iterations was hard-coded to 2).
    mask = binary_dilation(seg != 0, iterations=dilation)
    mask = binary_closing(mask, iterations=4)

    image = imageio.imread(image_path)
    # Robust contrast normalization of the image to [0, 1].
    lo, hi = np.percentile(image, 2), np.percentile(image, 98)
    print(lo, hi)
    image_modulated = np.clip(image, lo, hi).astype("float32")
    image_modulated -= lo
    image_modulated /= image_modulated.max()

    # Build a modulation mask that is 1 on the (dilated) segmentation and
    # falls off with the distance to it; the cubic power sharpens the falloff.
    modulation_mask = distance_transform_edt(~mask)
    modulation_mask /= modulation_mask.max()
    modulation_mask = 1 - modulation_mask
    modulation_mask[mask] = 1
    # Use np.power instead of np.pow; np.pow is only available in numpy >= 2.0.
    modulation_mask = np.power(modulation_mask, 3)
    modulation_mask *= modulation_strength
    image_modulated *= modulation_mask

    if view:
        import napari
        v = napari.Viewer()
        v.add_image(modulation_mask)
        v.add_image(image, visible=False)
        v.add_image(image_modulated)
        v.add_labels(mask, visible=False)
        napari.run()
        return
    imageio.imwrite(out_path, image_modulated, compression="zlib")


# Example inputs:
# image_path = "M_LR_000227_R/scale3/PV.tif"
# seg_path = "M_LR_000227_R/scale3/SGN_v2.tif"
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("image_path")
    parser.add_argument("seg_path")
    parser.add_argument("out_path")
    parser.add_argument("--view", "-v", action="store_true")
    parser.add_argument("--dilation", type=int, default=2)
    # Expose the modulation strength, which was previously not settable from the CLI.
    parser.add_argument("--modulation_strength", type=int, default=10)
    args = parser.parse_args()
    intensity_masking(
        args.image_path, args.seg_path, args.out_path,
        modulation_strength=args.modulation_strength, view=args.view, dilation=args.dilation,
    )
61 changes: 61 additions & 0 deletions scripts/la-vision/check_detections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import napari
import zarr


# Voxel size of the data (ZYX order) — presumably in micrometer; TODO confirm.
resolution = [3.0, 1.887779, 1.887779]
# Physical positions (x, y, z) at which to inspect the detections.
positions = [
    [2002.95539395823, 1899.9032205156411, 264.7747008147759]
]


def _load_from_mobie(bb):
    """Load the PV image crop and the SGN detections from the MoBIE project for the given bounding box."""
    image_path = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet/LaVision-M04/images/ome-zarr/PV.ome.zarr"
    image = zarr.open(image_path, mode="r")["s0"][bb]
    print(bb)

    detection_path = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet/LaVision-M04/images/ome-zarr/SGN_detect-v1.ome.zarr"
    detections = zarr.open(detection_path, mode="r")["s0"][bb]

    return image, detections


def _load_prediction(bb):
    """Load the crop of the SGN detection prediction for the given bounding box."""
    pred_path = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/LaVision-M04/SGN_detect-v1/predictions.zarr"
    return zarr.open(pred_path, mode="r")["prediction"][bb]


def _load_prediction_debug():
    """Load the locally stored debug prediction."""
    path = "./debug-pred/pred-v5.h5"
    # zarr.open returns a Group / Array, which does not implement the
    # context-manager protocol, so it must not be used in a with-statement.
    f = zarr.open(path, mode="r")
    return f["pred"][:]


def check_detection(position, halo=(32, 384, 384)):
    """Visually compare image data, debug prediction, and detections from MoBIE in napari.

    Args:
        position: The center position in physical (x, y, z) coordinates.
        halo: The half-extent of the crop in voxels (ZYX). Use an immutable
            default to avoid the mutable-default-argument pitfall.
    """
    # Convert the physical position (xyz) into a voxel bounding box (zyx).
    bb = tuple(
        slice(int(pos / re) - ha, int(pos / re) + ha) for pos, re, ha in zip(position[::-1], resolution, halo)
    )

    pv, detections_mobie = _load_from_mobie(bb)
    # Switch to _load_prediction(bb) to inspect the full prediction instead.
    # pred = _load_prediction(bb)
    pred = _load_prediction_debug()

    v = napari.Viewer()
    v.add_image(pv)
    v.add_image(pred)
    v.add_labels(detections_mobie)
    napari.run()


def main():
    """Check the detections for the first stored position."""
    check_detection(positions[0])


if __name__ == "__main__":
    main()
62 changes: 62 additions & 0 deletions scripts/la-vision/debug_prediction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import os
from functools import partial

import numpy as np
import torch
import zarr
from torch_em.transform.raw import standardize
from torch_em.util.prediction import predict_with_halo


# Voxel size of the data (ZYX order) — presumably in micrometer; TODO confirm.
resolution = [3.0, 1.887779, 1.887779]
# Physical positions (x, y, z) at which to run the debug prediction.
positions = [
    [2002.95539395823, 1899.9032205156411, 264.7747008147759]
]


def _load_from_mobie(bb):
    """Load the PV image crop from the MoBIE project for the given bounding box."""
    image_path = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet/LaVision-M04/images/ome-zarr/PV.ome.zarr"
    return zarr.open(image_path, mode="r")["s0"][bb]


def run_prediction(position, halo=(32, 384, 384)):
    """Run the SGN detection model on a crop around the given position and save the prediction.

    Args:
        position: The center position in physical (x, y, z) coordinates.
        halo: The half-extent of the crop in voxels (ZYX). Use an immutable
            default to avoid the mutable-default-argument pitfall.
    """
    # Convert the physical position (xyz) into a voxel bounding box (zyx).
    bb = tuple(
        slice(int(pos / re) - ha, int(pos / re) + ha) for pos, re, ha in zip(position[::-1], resolution, halo)
    )
    pv = _load_from_mobie(bb)
    # Standardize with the statistics of the whole crop.
    mean, std = np.mean(pv), np.std(pv)
    print(mean, std)
    preproc = partial(standardize, mean=mean, std=std)

    # Blocking for tiled prediction. Note: pred_halo is the per-block halo and is
    # distinct from the crop halo above (which previously reused the name 'halo').
    block_shape = (24, 256, 256)
    pred_halo = (8, 64, 64)

    model_path = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/SGN/sgn-detection-v1.pt"
    model = torch.load(model_path, weights_only=False)

    def postproc(x):
        # Clip to [0, 1] and normalize with the 99th percentile to stretch the response.
        x = np.clip(x, 0, 1)
        max_ = np.percentile(x, 99)
        x = x / max_
        return x

    pred = predict_with_halo(pv, model, [0], block_shape, pred_halo, preprocess=preproc, postprocess=postproc).squeeze()

    pred_name = "pred-v5"
    out_folder = "./debug-pred"
    os.makedirs(out_folder, exist_ok=True)

    out_path = os.path.join(out_folder, f"{pred_name}.h5")
    # zarr.open returns a Group, which does not implement the context-manager
    # protocol, so it must not be used in a with-statement.
    f = zarr.open(out_path, mode="w")
    f.create_dataset("pred", data=pred)


def main():
    """Run the debug prediction for the first stored position."""
    run_prediction(positions[0])


if __name__ == "__main__":
    main()
Loading
Loading