
Commit 938d4ba

Training is working
1 parent c065849 commit 938d4ba

4 files changed: +331 −0 lines changed
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
data/
Lines changed: 166 additions & 0 deletions
@@ -0,0 +1,166 @@
import numpy as np
import pandas as pd
import torch
import zarr

from skimage.filters import gaussian
from torch_em.util import ensure_tensor_with_channels


# Process point labels stored as napari-style CSV files.
# I don't actually think that we need the epsilon here, but will leave it for now.
def process_labels(label_path, shape, sigma, eps):
    labels = np.zeros(shape, dtype="float32")
    points = pd.read_csv(label_path)
    assert len(points.columns) == len(shape)
    # Round the point coordinates and clip them to the valid image range.
    coords = tuple(
        np.clip(np.round(points[ax].values).astype("int"), 0, shape[i] - 1)
        for i, ax in enumerate(points.columns)
    )
    labels[coords] = 1
    # Smooth the point mask into a Gaussian heat map.
    labels = gaussian(labels, sigma)
    # TODO better normalization?
    labels /= labels.max()
    return labels

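For illustration (not part of the commit): a minimal sketch of how process_labels turns point annotations into a normalized Gaussian heat map, using a hypothetical toy CSV.

import pandas as pd

# One point in a small (8, 16, 16) volume, napari column convention.
points = pd.DataFrame({"axis-0": [4.0], "axis-1": [8.0], "axis-2": [8.0]})
points.to_csv("toy_points.csv", index=False)

heatmap = process_labels("toy_points.csv", shape=(8, 16, 16), sigma=1, eps=1e-8)
assert heatmap.max() == 1.0  # normalized so the peak is exactly 1
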
class DetectionDataset(torch.utils.data.Dataset):
    max_sampling_attempts = 500

    def __init__(
        self,
        raw_image_paths,
        label_paths,
        patch_shape,
        raw_transform=None,
        label_transform=None,
        transform=None,
        dtype=torch.float32,
        label_dtype=torch.float32,
        n_samples=None,
        sampler=None,
        eps=1e-8,
        sigma=None,
        **kwargs,
    ):
        self.raw_images = raw_image_paths
        # TODO make this a parameter
        self.raw_key = "raw"
        self.label_images = label_paths
        self._ndim = 3

        assert len(patch_shape) == self._ndim
        self.patch_shape = patch_shape

        self.raw_transform = raw_transform
        self.label_transform = label_transform
        self.transform = transform
        self.sampler = sampler

        self.dtype = dtype
        self.label_dtype = label_dtype

        self.eps = eps
        self.sigma = sigma

        if n_samples is None:
            self._len = len(self.raw_images)
            self.sample_random_index = False
        else:
            self._len = n_samples
            self.sample_random_index = True

    def __len__(self):
        return self._len

    @property
    def ndim(self):
        return self._ndim

    def _sample_bounding_box(self, shape):
        if any(sh < psh for sh, psh in zip(shape, self.patch_shape)):
            raise NotImplementedError(
                f"Image padding is not supported yet. Data shape {shape}, patch shape {self.patch_shape}"
            )
        bb_start = [
            np.random.randint(0, sh - psh) if sh - psh > 0 else 0
            for sh, psh in zip(shape, self.patch_shape)
        ]
        return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.patch_shape))

    def _get_sample(self, index):
        if self.sample_random_index:
            index = np.random.randint(0, len(self.raw_images))
        raw, label = self.raw_images[index], self.label_images[index]

        raw = zarr.open(raw)[self.raw_key]
        # Note: this is quite inefficient, because we process the full volume rather than
        # just the requested bounding box.
        label = process_labels(label, raw.shape, self.sigma, self.eps)

        have_raw_channels = raw.ndim == 4  # 3D with channels
        have_label_channels = label.ndim == 4
        if have_label_channels:
            raise NotImplementedError("Multi-channel labels are not supported.")

        shape = raw.shape
        prefix_box = tuple()
        if have_raw_channels:
            # Heuristic: a small last axis means channel-last layout, otherwise channel-first.
            if shape[-1] < 16:
                shape = shape[:-1]
            else:
                shape = shape[1:]
                prefix_box = (slice(None),)

        bb = self._sample_bounding_box(shape)
        raw_patch = np.array(raw[prefix_box + bb])
        label_patch = np.array(label[bb])

        # If a sampler is given, re-draw patches until it accepts one.
        if self.sampler is not None:
            sample_id = 0
            while not self.sampler(raw_patch, label_patch):
                bb = self._sample_bounding_box(shape)
                raw_patch = np.array(raw[prefix_box + bb])
                label_patch = np.array(label[bb])
                sample_id += 1
                if sample_id > self.max_sampling_attempts:
                    raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")

        if have_raw_channels and len(prefix_box) == 0:
            raw_patch = raw_patch.transpose((3, 0, 1, 2))  # Channels, Depth, Height, Width

        return raw_patch, label_patch

    def __getitem__(self, index):
        raw, labels = self._get_sample(index)

        if self.raw_transform is not None:
            raw = self.raw_transform(raw)

        if self.label_transform is not None:
            labels = self.label_transform(labels)

        if self.transform is not None:
            raw, labels = self.transform(raw, labels)

        raw = ensure_tensor_with_channels(raw, ndim=self._ndim, dtype=self.dtype)
        labels = ensure_tensor_with_channels(labels, ndim=self._ndim, dtype=self.label_dtype)
        return raw, labels

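For illustration (not part of the commit): a minimal sketch of feeding the dataset to a PyTorch DataLoader; the paths are hypothetical placeholders.

from torch.utils.data import DataLoader

dataset = DetectionDataset(
    raw_image_paths=["training_data/images/sample.zarr"],  # hypothetical path
    label_paths=["training_data/labels/sample.csv"],  # hypothetical path
    patch_shape=(32, 96, 96),
    sigma=1,
    n_samples=100,
)
loader = DataLoader(dataset, batch_size=4, shuffle=True)
raw_batch, label_batch = next(iter(loader))  # each of shape (4, 1, 32, 96, 96)
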
if __name__ == "__main__":
    import napari

    raw_path = "training_data/images/10.1L_mid_IHCribboncount_5_Z.zarr"
    label_path = "training_data/labels/10.1L_mid_IHCribboncount_5_Z.csv"

    f = zarr.open(raw_path, "r")
    raw = f["raw"][:]

    labels = process_labels(label_path, shape=raw.shape, sigma=1, eps=1e-7)

    v = napari.Viewer()
    v.add_image(raw)
    v.add_image(labels)
    napari.run()
Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
import os
from glob import glob
from pathlib import Path

import h5py
import napari
import numpy as np
import pandas as pd
import zarr


def get_voxel_size(imaris_file):
    with h5py.File(imaris_file, "r") as f:
        info = f["/DataSetInfo/Image"]
        # Imaris stores these attributes as arrays of single-byte strings,
        # so the characters have to be joined before they can be parsed.
        ext = [[float(b"".join(info.attrs[f"ExtMin{i}"]).decode()),
                float(b"".join(info.attrs[f"ExtMax{i}"]).decode())] for i in range(3)]
        size = [int(b"".join(info.attrs[dim]).decode()) for dim in ["X", "Y", "Z"]]
    # Voxel size = physical extent divided by the number of voxels per axis.
    vsize = np.array([(max_ - min_) / s for (min_, max_), s in zip(ext, size)])
    return vsize

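For illustration (not part of the commit): why the byte-joining above is needed. Imaris writes these HDF5 attributes as arrays of single-byte strings, so a value like "512" arrives as separate characters (hypothetical value shown).

import numpy as np

attr = np.array([b"5", b"1", b"2"], dtype="|S1")  # how Imaris stores "512"
assert int(b"".join(attr).decode()) == 512
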
def extract_training_data(imaris_file, output_folder):
    with h5py.File(imaris_file, "r") as f:
        data = f["/DataSet/ResolutionLevel 0/TimePoint 0/Channel 0/Data"][:]
        points = f["/Scene/Content/Points0/CoordsXYZR"][:]
    # Drop the radius column and flip the coordinates from XYZ to ZYX order.
    points = points[:, :-1]
    points = points[:, ::-1]

    # TODO crop the data to the original shape.
    # Can we just crop the zero-padding ?!
    crop_box = np.where(data != 0)
    crop_box = tuple(slice(0, int(cb.max() + 1)) for cb in crop_box)
    data = data[crop_box]
    print(data.shape)

    # Scale the points to match the image dimensions.
    voxel_size = get_voxel_size(imaris_file)
    points /= voxel_size[None]

    if output_folder is None:
        # Without an output folder, just visualize the data for inspection.
        v = napari.Viewer()
        v.add_image(data)
        v.add_points(points)
        v.title = os.path.basename(imaris_file)
        napari.run()
    else:
        image_folder = os.path.join(output_folder, "images")
        os.makedirs(image_folder, exist_ok=True)

        label_folder = os.path.join(output_folder, "labels")
        os.makedirs(label_folder, exist_ok=True)

        fname = Path(imaris_file).stem
        image_file = os.path.join(image_folder, f"{fname}.zarr")
        label_file = os.path.join(label_folder, f"{fname}.csv")

        # Save the points in the napari CSV convention (axis-0, axis-1, axis-2).
        coords = pd.DataFrame(points, columns=["axis-0", "axis-1", "axis-2"])
        coords.to_csv(label_file, index=False)

        f = zarr.open(image_file, "a")
        f.create_dataset("raw", data=data)

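For illustration (not part of the commit): reading one extracted image/label pair back, using one of the training files listed below.

import pandas as pd
import zarr

volume = zarr.open("./training_data/images/4.1L_apex_IHCribboncount_Z.zarr", "r")["raw"][:]
points = pd.read_csv("./training_data/labels/4.1L_apex_IHCribboncount_Z.csv")
print(volume.shape, len(points), "annotated points")
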
# Files that look good for training:
# - 4.1L_apex_IHCribboncount_Z.ims
# - 4.1L_base_IHCribbons_Z.ims
# - 4.1L_mid_IHCribboncount_Z.ims
# - 4.2R_apex_IHCribboncount_Z.ims
# - 4.2R_apex_IHCribboncount_Z.ims
# - 6.2R_apex_IHCribboncount_Z.ims (very small crop)
# - 6.2R_base_IHCribbons_Z.ims
def main():
    files = sorted(glob("./data/synapse_stains/*.ims"))
    for ff in files:
        extract_training_data(ff, output_folder="./training_data")


if __name__ == "__main__":
    main()
Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
import os
import sys

from detection_dataset import DetectionDataset

# Make the training utilities from the czii-protein-challenge project importable.
sys.path.append("/home/pape/Work/my_projects/czii-protein-challenge")

from utils.training import supervised_training  # noqa

TRAIN_ROOT = "./training_data/images"
LABEL_ROOT = "./training_data/labels"


def get_paths(split):
    file_names = [
        "4.1L_apex_IHCribboncount_Z",
        "4.1L_base_IHCribbons_Z",
        "4.1L_mid_IHCribboncount_Z",
        "4.2R_apex_IHCribboncount_Z",
        "4.2R_apex_IHCribboncount_Z",
        "6.2R_apex_IHCribboncount_Z",
        "6.2R_base_IHCribbons_Z",
    ]
    image_paths = [os.path.join(TRAIN_ROOT, f"{fname}.zarr") for fname in file_names]
    label_paths = [os.path.join(LABEL_ROOT, f"{fname}.csv") for fname in file_names]

    # Use the last file for validation and everything else for training.
    if split == "train":
        image_paths = image_paths[:-1]
        label_paths = label_paths[:-1]
    else:
        image_paths = image_paths[-1:]
        label_paths = label_paths[-1:]

    return image_paths, label_paths

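For illustration (not part of the commit): a quick sanity check (hypothetical helper) that loads one patch through DetectionDataset before launching a full training run.

def check_paths():
    image_paths, label_paths = get_paths("train")
    dataset = DetectionDataset(
        raw_image_paths=image_paths, label_paths=label_paths,
        patch_shape=(32, 96, 96), sigma=1,
    )
    raw, labels = dataset[0]
    print(raw.shape, labels.shape)  # expect (1, 32, 96, 96) tensors
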
# TODO maybe add a sampler for the label data
def train():
    model_name = "synapse_detection_v1"

    train_paths, train_label_paths = get_paths("train")
    val_paths, val_label_paths = get_paths("val")
    # We need to give the paths for the test loader, although it's never used.
    test_paths, test_label_paths = val_paths, val_label_paths

    print("Start training with:")
    print(len(train_paths), "volumes for training")
    print(len(val_paths), "volumes for validation")

    patch_shape = [32, 96, 96]
    batch_size = 8
    check = False

    supervised_training(
        name=model_name,
        train_paths=train_paths,
        train_label_paths=train_label_paths,
        val_paths=val_paths,
        val_label_paths=val_label_paths,
        patch_shape=patch_shape, batch_size=batch_size,
        check=check,
        lr=1e-4,
        n_iterations=int(2.5e4),
        out_channels=1,
        augmentations=None,
        eps=1e-5,
        sigma=1,
        lower_bound=None,
        upper_bound=None,
        test_paths=test_paths,
        test_label_paths=test_label_paths,
        # save_root="",
        dataset_class=DetectionDataset,
    )


def main():
    train()


if __name__ == "__main__":
    main()
