Skip to content

Commit 0b7e4b2

Browse files
Update sgn detection training
1 parent 920230b commit 0b7e4b2

File tree

2 files changed

+39
-4
lines changed

2 files changed

+39
-4
lines changed

scripts/la-vision/train_sgn_detection.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
from utils.training.training import supervised_training # noqa
1313
from detection_dataset import DetectionDataset, MinPointSampler # noqa
1414

15-
# ROOT = "./la-vision-sgn-new" # noqa
16-
ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/SGN/sgn-detection"
15+
ROOT = "./la-vision-sgn-new/train/sgn-detection" # noqa
16+
# ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/SGN/sgn-detection"
1717

1818
TRAIN = os.path.join(ROOT, "images")
1919
TRAIN_EMPTY = os.path.join(ROOT, "empty_images")
@@ -79,7 +79,7 @@ def train():
7979
)
8080

8181
# For marmoset model
82-
sigma = (0.6, 3, 3)
82+
sigma = (1.6, 3, 3)
8383
# For mouse model
8484
# sigma = (1, 4, 4)
8585
supervised_training(

scripts/synapse_marker_detection/detection_dataset.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,40 @@ def process_labels_hacky(coords, shape, sigma, eps, bb=None):
104104
return labels
105105

106106

107+
def process_labels_stamped(coords, shape, sigma, eps, bb):
108+
offset = 7
109+
110+
if bb:
111+
(z_min, z_max), (y_min, y_max), (x_min, x_max) = [(s.start, s.stop) for s in bb]
112+
restricted_shape = (z_max - z_min, y_max - y_min, x_max - x_min)
113+
full_shape = tuple(sh + 2 * offset for sh in restricted_shape)
114+
labels = np.zeros(full_shape, dtype="float32")
115+
shape = restricted_shape
116+
else:
117+
full_shape = tuple(sh + 2 * offset for sh in shape)
118+
119+
labels = np.zeros(full_shape, dtype="float32")
120+
121+
z, y, x = coords
122+
if len(z) == 0:
123+
return np.zeros(shape, dtype="float32")
124+
coordinates = np.concatenate([z[:, None], y[:, None], x[:, None]], axis=1)
125+
126+
stamp = np.zeros((2*offset, 2*offset, 2*offset), dtype="float32")
127+
stamp[offset - 1, offset - 1, offset - 1] = 1
128+
stamp = gaussian(stamp, sigma=sigma)
129+
stamp /= stamp.max()
130+
131+
for coord in coordinates:
132+
bb = tuple(slice(co + offset - offset, co + offset + offset) for co in coord)
133+
val = np.maximum(labels[bb], stamp)
134+
labels[bb] = val
135+
136+
labels = labels[offset:-offset, offset:-offset, offset:-offset]
137+
assert labels.shape == shape
138+
return labels
139+
140+
107141
class DetectionDataset(torch.utils.data.Dataset):
108142
max_sampling_attempts = 500
109143

@@ -209,7 +243,8 @@ def _get_sample(self, index):
209243
# label = process_labels(coords, shape, self.sigma, self.eps, bb=bb)
210244

211245
# For SGN detection with data specfic hacks
212-
label = process_labels_hacky(coords, shape, self.sigma, self.eps, bb=bb)
246+
# label = process_labels_hacky(coords, shape, self.sigma, self.eps, bb=bb)
247+
label = process_labels_stamped(coords, shape, self.sigma, self.eps, bb=bb)
213248
# Having this halo actually makes sense in general!
214249
gap = 8
215250
gap_bb = np.s_[gap:-gap, gap:-gap, gap:-gap]

0 commit comments

Comments
 (0)