Skip to content

Commit 6a1236d

Browse files
Updates to synapse training
1 parent 6da20c5 commit 6a1236d

File tree

2 files changed

+68
-38
lines changed

2 files changed

+68
-38
lines changed

scripts/synapse_marker_detection/detection_dataset.py

Lines changed: 61 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,29 +9,61 @@
99

1010
# Process point labels stored as a napari-style CSV file (columns "axis-0", "axis-1", "axis-2").
1111
# I don't actually think that we need the epsilon here, but will leave it for now.
def process_labels(label_path, shape, sigma, eps, bb=None):
    """Render point annotations from a CSV file as a smoothed heatmap.

    The CSV is read with pandas and must provide the columns "axis-0",
    "axis-1" and "axis-2", which are used as z, y and x coordinates.

    Args:
        label_path: Path of the CSV file holding one point per row.
        shape: Shape of the full volume. It is the output shape when ``bb``
            is None and is ignored otherwise.
        sigma: Smoothing parameter forwarded unchanged to ``gaussian``.
        eps: Unused by this function; kept so existing callers keep working.
        bb: Optional bounding box given as three slices (z, y, x) with explicit
            integer ``start`` and ``stop``. When given, only points inside the
            box are rendered and the output has the shape of the box.

    Returns:
        A float32-initialised volume that is smoothed with ``gaussian``,
        divided by its maximum (plus 1e-7) and multiplied by 4.
    """
    points = pd.read_csv(label_path)

    # A single, explicit check decides whether a bounding box is in use, so the
    # shape computation and the coordinate shift can never disagree.
    use_bb = bb is not None
    if use_bb:
        offsets = tuple(s.start for s in bb)
        shape = tuple(s.stop - s.start for s in bb)

    assert len(points.columns) == len(shape), (
        f"Expected {len(shape)} coordinate columns, got {len(points.columns)}"
    )

    # Work on plain NumPy arrays: this avoids modifying the DataFrame columns in
    # place and avoids indexing `labels` with pandas objects further down.
    coords = [points[name].to_numpy(dtype="float64") for name in ("axis-0", "axis-1", "axis-2")]

    if use_bb:
        # Move the points into the coordinate frame of the box and drop the
        # ones that fall outside of it.
        coords = [coord - offset for coord, offset in zip(coords, offsets)]
        inside = np.logical_and.reduce(
            [np.logical_and(coord >= 0, coord < size) for coord, size in zip(coords, shape)]
        )
        coords = [coord[inside] for coord in coords]

    # Rounding can land exactly on the upper border (e.g. 39.6 -> 40 for size 40),
    # so the indices are clipped into the valid range after rounding.
    index = tuple(
        np.clip(np.round(coord).astype("int"), 0, size - 1) for coord, size in zip(coords, shape)
    )

    labels = np.zeros(shape, dtype="float32")
    labels[index] = 1
    # NOTE(review): `gaussian` is defined outside this block (presumably
    # skimage.filters.gaussian) - confirm against the file's imports.
    labels = gaussian(labels, sigma)
    # TODO better normalization?
    # The small constant keeps the division finite when no point lies in the volume.
    labels /= (labels.max() + 1e-7)
    # Scale so that the peak value is (approximately) 4 instead of 1.
    labels *= 4
    return labels
2548

2649

2750
class DetectionDataset(torch.utils.data.Dataset):
2851
max_sampling_attempts = 500
2952

@staticmethod
def compute_len(shape, patch_shape):
    """Estimate how many patches of ``patch_shape`` fit into a volume of ``shape``.

    Args:
        shape: Shape of the volume.
        patch_shape: Shape of a single patch, or None.

    Returns:
        1 if ``patch_shape`` is None. Otherwise the product of the per-axis
        ratios ``shape / patch_shape`` truncated to an int, but never less
        than 1, so a volume smaller than one patch does not produce a
        zero-length dataset.
    """
    if patch_shape is None:
        return 1
    n_samples = int(np.prod([sh / csh for sh, csh in zip(shape, patch_shape)]))
    # Truncation yields 0 when the ratios multiply to less than one; clamp so
    # the value stays usable as a dataset length.
    return max(1, n_samples)
60+
3061
def __init__(
3162
self,
32-
raw_image_paths,
33-
label_paths,
63+
raw_path,
64+
label_path,
3465
patch_shape,
66+
raw_key,
3567
raw_transform=None,
3668
label_transform=None,
3769
transform=None,
@@ -43,10 +75,9 @@ def __init__(
4375
sigma=None,
4476
**kwargs,
4577
):
46-
self.raw_images = raw_image_paths
47-
# TODO make this a parameter
48-
self.raw_key = "raw"
49-
self.label_images = label_paths
78+
self.raw_path = raw_path
79+
self.label_path = label_path
80+
self.raw_key = raw_key
5081
self._ndim = 3
5182

5283
assert len(patch_shape) == self._ndim
@@ -63,12 +94,13 @@ def __init__(
6394
self.eps = eps
6495
self.sigma = sigma
6596

97+
with zarr.open(self.raw_path, "r") as f:
98+
self.shape = f[self.raw_key].shape
99+
66100
if n_samples is None:
67-
self._len = len(self.raw_images)
68-
self.sample_random_index = False
101+
self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples
69102
else:
70103
self._len = n_samples
71-
self.sample_random_index = True
72104

73105
def __len__(self):
74106
return self._len
@@ -89,21 +121,19 @@ def _sample_bounding_box(self, shape):
89121
return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.patch_shape))
90122

91123
def _get_sample(self, index):
92-
if self.sample_random_index:
93-
index = np.random.randint(0, len(self.raw_images))
94-
raw, label = self.raw_images[index], self.label_images[index]
124+
raw, label_path = self.raw_path, self.label_path
95125

96126
raw = zarr.open(raw)[self.raw_key]
97-
# Note: this is quite inefficient, because we process the full crop rather than
98-
# just the requested bounding box.
99-
label = process_labels(label, raw.shape, self.sigma, self.eps)
127+
shape = raw.shape
128+
129+
bb = self._sample_bounding_box(shape)
130+
label = process_labels(label_path, shape, self.sigma, self.eps, bb=bb)
100131

101132
have_raw_channels = raw.ndim == 4 # 3D with channels
102133
have_label_channels = label.ndim == 4
103134
if have_label_channels:
104135
raise NotImplementedError("Multi-channel labels are not supported.")
105136

106-
shape = raw.shape
107137
prefix_box = tuple()
108138
if have_raw_channels:
109139
if shape[-1] < 16:
@@ -112,19 +142,19 @@ def _get_sample(self, index):
112142
shape = shape[1:]
113143
prefix_box = (slice(None), )
114144

115-
bb = self._sample_bounding_box(shape)
116145
raw_patch = np.array(raw[prefix_box + bb])
117-
label_patch = np.array(label[bb])
146+
label_patch = np.array(label)
118147

119148
if self.sampler is not None:
120-
sample_id = 0
121-
while not self.sampler(raw_patch, label_patch):
122-
bb = self._sample_bounding_box(shape)
123-
raw_patch = np.array(raw[prefix_box + bb])
124-
label_patch = np.array(label[bb])
125-
sample_id += 1
126-
if sample_id > self.max_sampling_attempts:
127-
raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
149+
assert False, "Sampler not implemented"
150+
# sample_id = 0
151+
# while not self.sampler(raw_patch, label_patch):
152+
# bb = self._sample_bounding_box(shape)
153+
# raw_patch = np.array(raw[prefix_box + bb])
154+
# label_patch = np.array(label[bb])
155+
# sample_id += 1
156+
# if sample_id > self.max_sampling_attempts:
157+
# raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
128158

129159
if have_raw_channels and len(prefix_box) == 0:
130160
raw_patch = raw_patch.transpose((3, 0, 1, 2)) # Channels, Depth, Height, Width

scripts/synapse_marker_detection/train_synapse_detection.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# sys.path.append("/home/pape/Work/my_projects/czii-protein-challenge")
77
sys.path.append("/user/pape41/u12086/Work/my_projects/czii-protein-challenge")
88

9-
from utils.training import supervised_training # noqa
9+
from utils.training.training import supervised_training # noqa
1010

1111
TRAIN_ROOT = "./training_data/images"
1212
LABEL_ROOT = "./training_data/labels"
@@ -49,9 +49,8 @@ def train():
4949
print(len(train_paths), "tomograms for training")
5050
print(len(val_paths), "tomograms for validation")
5151

52-
patch_shape = [32, 96, 96]
53-
54-
batch_size = 8
52+
patch_shape = [40, 112, 112]
53+
batch_size = 32
5554
check = False
5655

5756
supervised_training(
@@ -60,10 +59,11 @@ def train():
6059
train_label_paths=train_label_paths,
6160
val_paths=val_paths,
6261
val_label_paths=val_label_paths,
62+
raw_key="raw",
6363
patch_shape=patch_shape, batch_size=batch_size,
6464
check=check,
6565
lr=1e-4,
66-
n_iterations=int(2.5e4),
66+
n_iterations=int(5e4),
6767
out_channels=1,
6868
augmentations=None,
6969
eps=1e-5,
@@ -74,8 +74,8 @@ def train():
7474
test_label_paths=test_label_paths,
7575
# save_root="",
7676
dataset_class=DetectionDataset,
77-
n_samples_train=800,
78-
n_samples_val=80,
77+
n_samples_train=3200,
78+
n_samples_val=160,
7979
)
8080

8181

0 commit comments

Comments
 (0)