Commit 0534db2

Merge branch 'intensity-masking' of https://github.com/computational-cell-analytics/flamingo-tools into intensity-masking
2 parents 663d82e + e698bb1 commit 0534db2

File tree

3 files changed: +174 −6 lines changed
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
import os
from glob import glob

import imageio.v3 as imageio
import napari
import numpy as np
from skimage.measure import regionprops


def main():
    image_files = sorted(glob("la-vision-sgn-new/images/*.tif"))
    label_files = sorted(glob("la-vision-sgn-new/segmentation-postprocessed/*.tif"))

    for imf, lf in zip(image_files, label_files):
        im = imageio.imread(imf)
        labels = imageio.imread(lf)

        props = regionprops(labels)
        centers = np.array([prop.centroid for prop in props])

        name = os.path.basename(imf)
        print(name)

        v = napari.Viewer()
        v.add_image(im)
        v.add_labels(labels)
        v.add_points(centers, size=5, out_of_slice_display=True)
        v.title = name
        napari.run()


if __name__ == "__main__":
    main()
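The training script in the next file reads point annotations from CSV files with napari-style "axis-0"/"axis-1"/"axis-2" columns, the layout that load_labels in detection_dataset.py parses. As a hedged sketch of how the centroids visualized above could be exported into that layout — the export_centroids helper and the output folder name are illustrative, not part of this commit:

import os
from glob import glob

import imageio.v3 as imageio
import numpy as np
import pandas as pd
from skimage.measure import regionprops


def export_centroids(label_file, out_folder="la-vision-sgn-new/centroids"):
    """Write the centroids of a label image to CSV in the napari points layout."""
    os.makedirs(out_folder, exist_ok=True)
    labels = imageio.imread(label_file)
    # reshape(-1, 3) keeps the array 2D even if the label image is empty.
    centers = np.array([prop.centroid for prop in regionprops(labels)]).reshape(-1, 3)
    df = pd.DataFrame(centers, columns=["axis-0", "axis-1", "axis-2"])
    out_path = os.path.join(out_folder, os.path.basename(label_file).replace(".tif", ".csv"))
    df.to_csv(out_path, index=False)


for label_file in sorted(glob("la-vision-sgn-new/segmentation-postprocessed/*.tif")):
    export_centroids(label_file)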
Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
import os
import sys
import json
from glob import glob

from sklearn.model_selection import train_test_split

sys.path.append("/home/pape/Work/my_projects/czii-protein-challenge")
sys.path.append("/user/pape41/u12086/Work/my_projects/czii-protein-challenge")
sys.path.append("../synapse_marker_detection")

from utils.training.training import supervised_training  # noqa
from detection_dataset import DetectionDataset, MinPointSampler  # noqa

ROOT = "./la-vision-sgn-new"  # noqa

TRAIN = os.path.join(ROOT, "images")
TRAIN_EMPTY = os.path.join(ROOT, "empty_images")

LABEL = os.path.join(ROOT, "centroids")
LABEL_EMPTY = os.path.join(ROOT, "empty_centroids")


def _get_paths(split, train_folder, label_folder, n=None):
    image_paths = sorted(glob(os.path.join(train_folder, "*.tif")))
    label_paths = sorted(glob(os.path.join(label_folder, "*.csv")))
    assert len(image_paths) == len(label_paths)
    if n is not None:
        image_paths, label_paths = image_paths[:n], label_paths[:n]

    train_images, val_images, train_labels, val_labels = train_test_split(
        image_paths, label_paths, test_size=1, random_state=42
    )

    if split == "train":
        image_paths = train_images
        label_paths = train_labels
    else:
        image_paths = val_images
        label_paths = val_labels

    return image_paths, label_paths


def get_paths(split):
    image_paths, label_paths = _get_paths(split, TRAIN, LABEL)
    empty_image_paths, empty_label_paths = _get_paths(split, TRAIN_EMPTY, LABEL_EMPTY, n=4)
    return image_paths + empty_image_paths, label_paths + empty_label_paths


def train():

    model_name = "sgn-low-res-detection-v1"

    train_paths, train_label_paths = get_paths("train")
    val_paths, val_label_paths = get_paths("val")
    # We need to give the paths for the test loader, although it's never used.
    test_paths, test_label_paths = val_paths, val_label_paths

    print("Start training with:")
    print(len(train_paths), "tomograms for training")
    print(len(val_paths), "tomograms for validation")

    patch_shape = [48, 256, 256]
    batch_size = 8
    check = False

    checkpoint_path = f"./checkpoints/{model_name}"
    os.makedirs(checkpoint_path, exist_ok=True)
    with open(os.path.join(checkpoint_path, "splits.json"), "w") as f:
        json.dump(
            {
                "train": {"images": train_paths, "labels": train_label_paths},
                "val": {"images": val_paths, "labels": val_label_paths},
            },
            f, indent=2, sort_keys=True
        )

    supervised_training(
        name=model_name,
        train_paths=train_paths,
        train_label_paths=train_label_paths,
        val_paths=val_paths,
        val_label_paths=val_label_paths,
        raw_key=None,
        patch_shape=patch_shape, batch_size=batch_size,
        check=check,
        lr=1e-4,
        n_iterations=int(1e5),
        out_channels=1,
        augmentations=None,
        eps=1e-5,
        sigma=4,
        lower_bound=None,
        upper_bound=None,
        test_paths=test_paths,
        test_label_paths=test_label_paths,
        # save_root="",
        dataset_class=DetectionDataset,
        n_samples_train=3200,
        n_samples_val=160,
        sampler=MinPointSampler(min_points=1, p_reject=0.5),
    )


def main():
    train()


if __name__ == "__main__":
    main()
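A note on the split logic in _get_paths: train_test_split with an integer test_size holds out exactly that many samples, so each call reserves a single image (and its matching label file) for validation, and the fixed random_state makes that held-out file reproducible across runs. A minimal sketch with made-up file names:

from sklearn.model_selection import train_test_split

images = [f"im_{i}.tif" for i in range(5)]
labels = [f"im_{i}.csv" for i in range(5)]

# An integer test_size means an absolute number of held-out samples.
train_images, val_images, train_labels, val_labels = train_test_split(
    images, labels, test_size=1, random_state=42
)
print(len(train_images), len(val_images))  # -> 4 1
print(val_images)  # the same single file on every run, due to random_state=42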

scripts/synapse_marker_detection/detection_dataset.py

Lines changed: 30 additions & 6 deletions
@@ -1,3 +1,4 @@
+import imageio.v3 as imageio
 import numpy as np
 import pandas as pd
 import torch
@@ -38,7 +39,6 @@ def __call__(self, x: np.ndarray, n_points: int) -> bool:
 
 def load_labels(label_path, shape, bb):
     points = pd.read_csv(label_path)
-    assert len(points.columns) == len(shape)
     z_coords, y_coords, x_coords = points["axis-0"].values, points["axis-1"].values, points["axis-2"].values
 
     if bb is not None:
@@ -85,6 +85,25 @@ def process_labels(coords, shape, sigma, eps, bb=None):
     return labels
 
 
+def process_labels_hacky(coords, shape, sigma, eps, bb=None):
+
+    if bb:
+        (z_min, z_max), (y_min, y_max), (x_min, x_max) = [(s.start, s.stop) for s in bb]
+        restricted_shape = (z_max - z_min, y_max - y_min, x_max - x_min)
+        labels = np.zeros(restricted_shape, dtype="float32")
+        shape = restricted_shape
+    else:
+        labels = np.zeros(shape, dtype="float32")
+
+    labels[coords] = 1
+    labels = gaussian(labels, sigma)
+    labels = labels.clip(0, 0.0075)
+    labels /= (labels.max() + 1e-7)
+    labels *= 4
+    labels = labels.clip(0, 1)
+    return labels
+
+
 class DetectionDataset(torch.utils.data.Dataset):
     max_sampling_attempts = 500
 
@@ -132,8 +151,8 @@ def __init__(
         self.eps = eps
         self.sigma = sigma
 
-        with zarr.open(self.raw_path, "r") as f:
-            self.shape = f[self.raw_key].shape
+        self.raw = imageio.imread(self.raw_path) if raw_key is None else zarr.open(self.raw_path, "r")[raw_key][:]
+        self.shape = self.raw.shape
 
         if n_samples is None:
             self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples
@@ -159,9 +178,8 @@ def _sample_bounding_box(self, shape):
         return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.patch_shape))
 
     def _get_sample(self, index):
-        raw, label_path = self.raw_path, self.label_path
+        raw, label_path = self.raw, self.label_path
 
-        raw = zarr.open(raw)[self.raw_key]
         have_raw_channels = raw.ndim == 4  # 3D with channels
         shape = raw.shape
 
@@ -187,7 +205,13 @@ def _get_sample(self, index):
         if sample_id > self.max_sampling_attempts:
             raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
 
-        label = process_labels(coords, shape, self.sigma, self.eps, bb=bb)
+        # For synapse detection.
+        # label = process_labels(coords, shape, self.sigma, self.eps, bb=bb)
+
+        # For SGN detection with data-specific hacks.
+        label = process_labels_hacky(coords, shape, self.sigma, self.eps, bb=bb)
+        gap = 6
+        raw_patch, label = raw_patch[gap:-gap], label[gap:-gap]
 
        have_label_channels = label.ndim == 4
        if have_label_channels:
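To see what the clip/rescale chain in process_labels_hacky produces, here is a standalone numeric sketch. It assumes gaussian refers to skimage.filters.gaussian, which detection_dataset.py presumably already imports; the volume shape and the single synthetic point are made up:

import numpy as np
from skimage.filters import gaussian

# One synthetic detection in an otherwise empty volume.
labels = np.zeros((32, 64, 64), dtype="float32")
labels[16, 32, 32] = 1

heatmap = gaussian(labels, sigma=4)
print(heatmap.max())  # a small peak, far below 1

# The hacky chain: cap hotspots, renormalize, then saturate the top.
target = heatmap.clip(0, 0.0075)  # caps regions where several points overlap
target /= (target.max() + 1e-7)   # renormalize to roughly [0, 1]
target *= 4                       # push everything above 0.25 of the max past 1
target = target.clip(0, 1)        # ...and clip, yielding a flat-topped blob
print(target.max(), (target == 1).sum())  # 1.0, plus a plateau of saturated voxels

Compared to process_labels, this turns each narrow Gaussian peak into a wide saturated plateau; together with the gap = 6 crop along the first (z) axis in _get_sample, it reads as a workaround tailored to this dataset, as the comment in the diff suggests.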
