
Commit df9b23d

Updates to validation, synapse detection training and low res export impl
1 parent aefeba1

File tree: 10 files changed, +447 −48 lines


scripts/export_lower_resolution.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+import argparse
+import os
+
+import numpy as np
+import pandas as pd
+import tifffile
+import zarr
+
+from flamingo_tools.s3_utils import get_s3_path, BUCKET_NAME, SERVICE_ENDPOINT
+from skimage.segmentation import relabel_sequential
+
+
+def filter_component(fs, segmentation, cochlea, seg_name):
+    # First, we download the MoBIE table for this segmentation.
+    internal_path = os.path.join(BUCKET_NAME, cochlea, "tables", seg_name, "default.tsv")
+    with fs.open(internal_path, "r") as f:
+        table = pd.read_csv(f, sep="\t")
+
+    # Then we get the ids for the components and use them to filter the segmentation.
+    component_mask = np.isin(table.component_labels.values, [1])
+    keep_label_ids = table.label_id.values[component_mask].astype("int64")
+    filter_mask = ~np.isin(segmentation, keep_label_ids)
+    segmentation[filter_mask] = 0
+
+    segmentation, _, _ = relabel_sequential(segmentation)
+    return segmentation
+
+
+def export_lower_resolution(args):
+    output_folder = os.path.join(args.output_folder, args.cochlea, f"scale{args.scale}")
+    os.makedirs(output_folder, exist_ok=True)
+
+    input_key = f"s{args.scale}"
+    for channel in args.channels:
+        out_path = os.path.join(output_folder, f"{channel}.tif")
+        if os.path.exists(out_path):
+            continue
+
+        print("Exporting channel", channel)
+        internal_path = os.path.join(args.cochlea, "images", "ome-zarr", f"{channel}.ome.zarr")
+        s3_store, fs = get_s3_path(internal_path, bucket_name=BUCKET_NAME, service_endpoint=SERVICE_ENDPOINT)
+        with zarr.open(s3_store, mode="r") as f:
+            data = f[input_key][:]
+        print(data.shape)
+        if args.filter_by_component:
+            data = filter_component(fs, data, args.cochlea, channel)
+        tifffile.imwrite(out_path, data, bigtiff=True, compression="zlib")
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--cochlea", "-c", required=True)
+    parser.add_argument("--scale", "-s", type=int, required=True)
+    parser.add_argument("--output_folder", "-o", required=True)
+    parser.add_argument("--channels", nargs="+", default=["PV", "VGlut3", "CTBP2"])
+    parser.add_argument("--filter_by_component", action="store_true")
+    args = parser.parse_args()
+
+    export_lower_resolution(args)
+
+
+if __name__ == "__main__":
+    main()
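Reviewer note: to make the filtering step in `filter_component` easier to follow, here is a minimal, self-contained sketch of the same mask-and-relabel pattern. The toy segmentation and the `keep_label_ids` values are made up for illustration, standing in for the ids read from the MoBIE table:

import numpy as np
from skimage.segmentation import relabel_sequential

# Toy 2D segmentation with background 0 and objects 1, 2, 3.
segmentation = np.array([
    [0, 1, 1],
    [2, 2, 3],
    [3, 3, 0],
])

# Stand-in for the label ids selected from the component table.
keep_label_ids = np.array([1, 3], dtype="int64")

# Zero out everything not in the keep list, then map the surviving
# ids {1, 3} onto the dense range {1, 2}, as filter_component does.
segmentation[~np.isin(segmentation, keep_label_ids)] = 0
segmentation, _, _ = relabel_sequential(segmentation)
print(segmentation)
# -> [[0 1 1]
#     [0 0 2]
#     [2 2 0]]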

scripts/synapse_marker_detection/detection_dataset.py

Lines changed: 68 additions & 29 deletions
Original fileDiff line change
@@ -7,22 +7,42 @@
 from torch_em.util import ensure_tensor_with_channels
 
 
-# Process labels stored in json napari style.
-# I don't actually think that we need the epsilon here, but will leave it for now.
-def process_labels(label_path, shape, sigma, eps, bb=None):
-    points = pd.read_csv(label_path)
+class MinPointSampler:
+    """A sampler to reject samples with a low number of annotated points in the labels.
+
+    Args:
+        min_points: The minimal number of annotated points for accepting a sample.
+        p_reject: The probability for rejecting a sample that does not
+            meet the criterion.
+    """
+    def __init__(self, min_points: int, p_reject: float = 1.0):
+        self.min_points = min_points
+        self.p_reject = p_reject
+
+    def __call__(self, x: np.ndarray, n_points: int) -> bool:
+        """Check the sample.
+
+        Args:
+            x: The raw data.
+            n_points: The number of annotated points in the sample.
+
+        Returns:
+            Whether to accept this sample.
+        """
+        if n_points > self.min_points:
+            return True
+        else:
+            return np.random.rand() > self.p_reject
 
-    if bb:
-        (z_min, z_max), (y_min, y_max), (x_min, x_max) = [(s.start, s.stop) for s in bb]
-        restricted_shape = (z_max - z_min, y_max - y_min, x_max - x_min)
-        labels = np.zeros(restricted_shape, dtype="float32")
-        shape = restricted_shape
-    else:
-        labels = np.zeros(shape, dtype="float32")
 
+def load_labels(label_path, shape, bb):
+    points = pd.read_csv(label_path)
     assert len(points.columns) == len(shape)
-    z_coords, y_coords, x_coords = points["axis-0"], points["axis-1"], points["axis-2"]
+    z_coords, y_coords, x_coords = points["axis-0"].values, points["axis-1"].values, points["axis-2"].values
+
     if bb is not None:
+        (z_min, z_max), (y_min, y_max), (x_min, x_max) = [(s.start, s.stop) for s in bb]
         z_coords -= z_min
         y_coords -= y_min
         x_coords -= x_min
@@ -32,13 +52,31 @@ def process_labels(label_path, shape, sigma, eps, bb=None):
             np.logical_and(x_coords >= 0, x_coords < (x_max - x_min)),
         ])
         z_coords, y_coords, x_coords = z_coords[mask], y_coords[mask], x_coords[mask]
+        restricted_shape = (z_max - z_min, y_max - y_min, x_max - x_min)
+        shape = restricted_shape
 
+    n_points = len(z_coords)
     coords = tuple(
         np.clip(np.round(coord).astype("int"), 0, coord_max - 1) for coord, coord_max in zip(
             (z_coords, y_coords, x_coords), shape
         )
     )
 
+    return coords, n_points
+
+
+# Process labels stored in json napari style.
+# I don't actually think that we need the epsilon here, but will leave it for now.
+def process_labels(coords, shape, sigma, eps, bb=None):
+
+    if bb:
+        (z_min, z_max), (y_min, y_max), (x_min, x_max) = [(s.start, s.stop) for s in bb]
+        restricted_shape = (z_max - z_min, y_max - y_min, x_max - x_min)
+        labels = np.zeros(restricted_shape, dtype="float32")
+        shape = restricted_shape
+    else:
+        labels = np.zeros(shape, dtype="float32")
+
     labels[coords] = 1
     labels = gaussian(labels, sigma)
     # TODO better normalization?
@@ -124,16 +162,10 @@ def _get_sample(self, index):
         raw, label_path = self.raw_path, self.label_path
 
         raw = zarr.open(raw)[self.raw_key]
+        have_raw_channels = raw.ndim == 4  # 3D with channels
         shape = raw.shape
 
         bb = self._sample_bounding_box(shape)
-        label = process_labels(label_path, shape, self.sigma, self.eps, bb=bb)
-
-        have_raw_channels = raw.ndim == 4  # 3D with channels
-        have_label_channels = label.ndim == 4
-        if have_label_channels:
-            raise NotImplementedError("Multi-channel labels are not supported.")
-
         prefix_box = tuple()
         if have_raw_channels:
             if shape[-1] < 16:
@@ -143,18 +175,25 @@
             prefix_box = (slice(None), )
 
         raw_patch = np.array(raw[prefix_box + bb])
-        label_patch = np.array(label)
 
+        coords, n_points = load_labels(label_path, shape, bb)
         if self.sampler is not None:
-            assert False, "Sampler not implemented"
-            # sample_id = 0
-            # while not self.sampler(raw_patch, label_patch):
-            #     bb = self._sample_bounding_box(shape)
-            #     raw_patch = np.array(raw[prefix_box + bb])
-            #     label_patch = np.array(label[bb])
-            #     sample_id += 1
-            #     if sample_id > self.max_sampling_attempts:
-            #         raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
+            sample_id = 0
+            while not self.sampler(raw_patch, n_points):
+                bb = self._sample_bounding_box(shape)
+                raw_patch = np.array(raw[prefix_box + bb])
+                coords, n_points = load_labels(label_path, shape, bb)
+                sample_id += 1
+                if sample_id > self.max_sampling_attempts:
+                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
+
+        label = process_labels(coords, shape, self.sigma, self.eps, bb=bb)
+
+        have_label_channels = label.ndim == 4
+        if have_label_channels:
+            raise NotImplementedError("Multi-channel labels are not supported.")
+
+        label_patch = np.array(label)
 
         if have_raw_channels and len(prefix_box) == 0:
            raw_patch = raw_patch.transpose((3, 0, 1, 2))  # Channels, Depth, Height, Width
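Reviewer note: the sampler protocol used by the rewritten `_get_sample` is simple enough to try in isolation. A minimal sketch (a condensed copy of the `MinPointSampler` above; the patch shape and point counts are invented for illustration):

import numpy as np

class MinPointSampler:
    # Condensed copy of the sampler above: accept a patch outright if it
    # contains more than min_points annotated points, otherwise reject it
    # with probability p_reject.
    def __init__(self, min_points, p_reject=1.0):
        self.min_points = min_points
        self.p_reject = p_reject

    def __call__(self, x, n_points):
        return n_points > self.min_points or np.random.rand() > self.p_reject

sampler = MinPointSampler(min_points=1, p_reject=0.6)
raw_patch = np.zeros((40, 112, 112), dtype="float32")  # dummy raw data
print(sampler(raw_patch, n_points=5))  # always True: enough points
print(sampler(raw_patch, n_points=0))  # True only ~40% of the time

Note that `raw_patch` is passed but unused by this sampler; it is kept to match the `(x, n_points)` call signature that the dataset's rejection loop expects.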

scripts/synapse_marker_detection/extract_training_data.py

Lines changed: 55 additions & 8 deletions
Original fileDiff line change
@@ -19,23 +19,30 @@ def get_voxel_size(imaris_file):
     return vsize
 
 
-def extract_training_data(imaris_file, output_folder):
+def extract_training_data(imaris_file, output_folder, crop=True, scale=True):
+    point_key = "/Scene/Content/Points0/CoordsXYZR"
     with h5py.File(imaris_file, "r") as f:
+        if point_key not in f:
+            print("Skipping", imaris_file, "due to missing annotations")
+            return
         data = f["/DataSet/ResolutionLevel 0/TimePoint 0/Channel 0/Data"][:]
-        points = f["/Scene/Content/Points0/CoordsXYZR"][:]
+        points = f[point_key][:]
     points = points[:, :-1]
     points = points[:, ::-1]
 
     # TODO crop the data to the original shape.
     # Can we just crop the zero-padding ?!
-    crop_box = np.where(data != 0)
-    crop_box = tuple(slice(0, int(cb.max() + 1)) for cb in crop_box)
-    data = data[crop_box]
-    print(data.shape)
+    if crop:
+        crop_box = np.where(data != 0)
+        crop_box = tuple(slice(0, int(cb.max() + 1)) for cb in crop_box)
+        data = data[crop_box]
 
     # Scale the points to match the image dimensions.
     voxel_size = get_voxel_size(imaris_file)
-    points /= voxel_size[None]
+    if scale:
+        points /= voxel_size[None]
+
+    print(data.shape, voxel_size)
 
     if output_folder is None:
         v = napari.Viewer()
@@ -69,11 +76,51 @@ def extract_training_data(imaris_file, output_folder):
 # - 4.2R_apex_IHCribboncount_Z.ims
 # - 6.2R_apex_IHCribboncount_Z.ims (very small crop)
 # - 6.2R_base_IHCribbons_Z.ims
-def main():
+def process_training_data_v1():
     files = sorted(glob("./data/synapse_stains/*.ims"))
     for ff in files:
         extract_training_data(ff, output_folder="./training_data")
 
 
+def process_training_data_v2(visualize=True):
+    input_root = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/ImageCropsIHC_synapses"
+
+    train_output = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/training_data/v2"  # noqa
+    test_output = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/test/v2"  # noqa
+
+    train_folders = ["M78L_IHC-synapse_crops"]
+    test_folders = ["M226L_IHC-synapse_crops", "M226R_IHC-synapsecrops"]
+
+    valid_files = [
+        "m78l_apexp2718_cr-ctbp2.ims",
+        "m226r_apex_p1268_pv-ctbp2.ims",
+        "m226r_base_p800_vglut3-ctbp2.ims",
+    ]
+
+    for folder in train_folders + test_folders:
+
+        if visualize:
+            output_folder = None
+        elif folder in train_folders:
+            output_folder = train_output
+            os.makedirs(output_folder, exist_ok=True)
+        else:
+            output_folder = test_output
+            os.makedirs(output_folder, exist_ok=True)
+
+        imaris_files = sorted(glob(os.path.join(input_root, folder, "*.ims")))
+        for imaris_file in imaris_files:
+            fname = os.path.basename(imaris_file)
+            if fname not in valid_files:
+                continue
+            print(fname)
+            extract_training_data(imaris_file, output_folder, crop=True, scale=True)
+
+
+def main():
+    # process_training_data_v1()
+    process_training_data_v2(visualize=False)
+
+
 if __name__ == "__main__":
     main()
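Reviewer note: the coordinate handling in `extract_training_data` packs three conventions into two slicing operations. Imaris stores each point as XYZR in physical units; the script drops the radius column, flips the axis order to ZYX, and divides by the voxel size to get voxel indices. A small sketch with made-up numbers:

import numpy as np

points = np.array([[10.0, 20.0, 5.0, 0.5]])  # one Imaris point: (x, y, z, r)
points = points[:, :-1]                      # drop the radius -> (x, y, z)
points = points[:, ::-1]                     # flip axis order -> (z, y, x)

voxel_size = np.array([2.0, 1.0, 1.0])       # hypothetical (z, y, x) spacing
points /= voxel_size[None]                   # physical units -> voxel indices
print(points)                                # [[ 2.5 20.  10. ]]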

scripts/synapse_marker_detection/train_synapse_detection.py

Lines changed: 10 additions & 5 deletions
Original fileDiff line change
@@ -1,14 +1,14 @@
 import os
 import sys
 
-from detection_dataset import DetectionDataset
+from detection_dataset import DetectionDataset, MinPointSampler
 
 sys.path.append("/home/pape/Work/my_projects/czii-protein-challenge")
 sys.path.append("/user/pape41/u12086/Work/my_projects/czii-protein-challenge")
 
 from utils.training.training import supervised_training  # noqa
 
-ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/training_data/v1"  # noqa
+ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/training_data/v2"  # noqa
 TRAIN_ROOT = os.path.join(ROOT, "images")
 LABEL_ROOT = os.path.join(ROOT, "labels")
 
@@ -21,6 +21,7 @@ def get_paths(split):
         "4.2R_apex_IHCribboncount_Z",
         "4.2R_apex_IHCribboncount_Z",
         "6.2R_apex_IHCribboncount_Z",
+        "m78l_apexp2718_cr-ctbp2",
         "6.2R_base_IHCribbons_Z",
     ]
     image_paths = [os.path.join(TRAIN_ROOT, f"{fname}.zarr") for fname in file_names]
@@ -33,13 +34,16 @@ def get_paths(split):
         image_paths = image_paths[-1:]
         label_paths = label_paths[-1:]
 
+    for path in image_paths:
+        assert os.path.exists(path), path
+
     return image_paths, label_paths
 
 
 # TODO maybe add a sampler for the label data
 def train():
 
-    model_name = "synapse_detection_v1"
+    model_name = "synapse_detection_v2"
 
     train_paths, train_label_paths = get_paths("train")
     val_paths, val_label_paths = get_paths("val")
@@ -52,7 +56,7 @@ def train():
 
     patch_shape = [40, 112, 112]
     batch_size = 32
-    check = False
+    check = True
 
     supervised_training(
         name=model_name,
@@ -64,7 +68,7 @@
         patch_shape=patch_shape, batch_size=batch_size,
         check=check,
         lr=1e-4,
-        n_iterations=int(5e4),
+        n_iterations=int(1e5),
         out_channels=1,
         augmentations=None,
         eps=1e-5,
@@ -77,6 +81,7 @@
         dataset_class=DetectionDataset,
         n_samples_train=3200,
         n_samples_val=160,
+        sampler=MinPointSampler(min_points=1, p_reject=0.6),
     )
 
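Reviewer note: a back-of-the-envelope check on the `MinPointSampler(min_points=1, p_reject=0.6)` setting used here. Patches with at most one point still pass 40% of the time, so the rejection loop shifts, but does not eliminate, the share of (near-)empty patches. Assuming, hypothetically, that half of all randomly drawn patches are empty:

# Hypothetical fraction of randomly drawn patches with <= 1 point.
f_empty = 0.5

# Empty patches survive rejection with probability 1 - p_reject = 0.4;
# patches with enough points are always accepted.
accepted_empty = f_empty * 0.4
accepted_full = (1 - f_empty) * 1.0

# Share of (near-)empty patches among accepted samples: ~0.29 instead of 0.5.
print(accepted_empty / (accepted_empty + accepted_full))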
scripts/validation/IHCs/run_evaluation.py

Lines changed: 2 additions & 2 deletions
Original fileDiff line change
@@ -3,7 +3,7 @@
 
 import pandas as pd
 from flamingo_tools.validation import (
-    fetch_data_for_evaluation, parse_annotation_path, compute_scores_for_annotated_slice
+    fetch_data_for_evaluation, _parse_annotation_path, compute_scores_for_annotated_slice
 )
 
 ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/AnnotatedImageCrops/F1ValidationIHCs"
@@ -29,7 +29,7 @@ def run_evaluation(root, annotation_folders, result_file, cache_folder):
         annotations = sorted(glob(os.path.join(root, folder, "*.csv")))
         for annotation_path in annotations:
             print(annotation_path)
-            cochlea, slice_id = parse_annotation_path(annotation_path)
+            cochlea, slice_id = _parse_annotation_path(annotation_path)
 
             # For the cochlea M_LR_000226_R the actual component is 2, not 1
             component = 2 if "226_R" in cochlea else 1
