Skip to content

Commit 55125f4

Browse files
authored
Merge pull request #9 from AllenNeuralDynamics/refactor-training
feat: foreground sampling via segmentation
2 parents 7018671 + 08952d6 commit 55125f4

File tree

2 files changed

+78
-16
lines changed

2 files changed

+78
-16
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ dynamic = ["version"]
1919
dependencies = [
2020
'bm4d',
2121
'boto3',
22+
'fastremap',
2223
'gcsfs',
2324
'google-cloud-storage',
2425
'interrogate',

src/aind_exaspim_image_compression/machine_learning/data_handling.py

Lines changed: 77 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from torch.utils.data import Dataset
1616
from tqdm import tqdm
1717

18+
import fastremap
1819
import numpy as np
1920
import random
2021
import tensorstore as ts
@@ -178,44 +179,68 @@ def sample_voxel(self, brain_id):
178179
def sample_foreground_voxel(self, brain_id):
179180
if self.skeletons[brain_id] is not None and np.random.random() > 0.5:
180181
return self.sample_skeleton_voxel(brain_id)
181-
#elif self.segmentations[brain_id] is not None:
182-
# return self.sample_segmentation_voxel(brain_id)
183-
else:
184-
return self.sample_bright_voxel(brain_id)
182+
else: # self.segmentations[brain_id] is not None:
183+
return self.sample_segmentation_voxel(brain_id)
184+
185+
#else:
186+
# return self.sample_bright_voxel(brain_id)
185187

186188
def sample_skeleton_voxel(self, brain_id):
187189
idx = random.randint(0, len(self.foreground[brain_id]) - 1)
188190
shift = np.random.randint(0, 16, size=3)
189191
return tuple(self.foreground[brain_id][idx] + shift)
190192

191193
def sample_segmentation_voxel(self, brain_id):
194+
while self.segmentations[brain_id] is None:
195+
brain_id = self.sample_brain()
196+
192197
cnt = 0
193-
while cnt < 32:
198+
best_voxel, max_volume = None, 0
199+
while max_volume < 4000:
194200
# Read random image patch
195201
voxel = self.sample_interior_voxel(brain_id)
196202
labels_patch = self.read_precomputed_patch(brain_id, voxel)
197203

198204
# Check if labels patch has large enough object
199-
# --> call fastremap
200-
# --> find largest object
201-
return voxel
205+
vals, cnts = fastremap.unique(labels_patch, return_counts=True)
206+
if len(cnts) > 1:
207+
volume = np.max(cnts[1:])
208+
if volume > max_volume:
209+
best_voxel = voxel
210+
max_volume = volume
211+
212+
# Check number of tries
213+
cnt += 1
214+
if cnt > 32:
215+
break
216+
217+
print("Brain_ID:", brain_id)
218+
print("Largest Volume:", max_volume)
219+
print("Voxel:", best_voxel)
220+
print("# Attempts:", cnt)
221+
return best_voxel
202222

203223
def sample_bright_voxel(self, brain_id):
204224
cnt = 0
205225
brightest_voxel, max_brightness = None, 0
206-
while cnt < 32:
226+
while max_brightness < self.min_brightness:
207227
# Read random image patch
208228
voxel = self.sample_interior_voxel(brain_id)
209229
img_patch = self.read_patch(brain_id, voxel)
210230

211231
# Check if image patch is bright enough
212232
brightness = np.max(img_patch)
213-
if brightness >= self.min_brightness:
214-
return voxel
215-
elif brightness > max_brightness:
233+
if brightness > max_brightness:
216234
brightest_voxel = voxel
217235
max_brightness = brightness
236+
237+
# Check number of tries
218238
cnt += 1
239+
if cnt > 32:
240+
break
241+
print("Brain_ID:", brain_id)
242+
print("Max Brightness:", max_brightness)
243+
print("# Attempts:", cnt)
219244
return brightest_voxel
220245

221246
def sample_interior_voxel(self, brain_id):
@@ -238,21 +263,57 @@ def __len__(self):
238263
return self.n_examples_per_epoch
239264

240265
def read_patch(self, brain_id, center):
266+
"""
267+
Reads an image patch from a Zarr array.
268+
269+
Parameters
270+
----------
271+
brain_id : str
272+
Unique identifier of the sampled brain.
273+
center : Tuple[int]
274+
Center of image patch to be read.
275+
276+
Returns
277+
-------
278+
numpy.ndarray
279+
Image patch.
280+
"""
241281
s = img_util.get_slices(center, self.patch_shape)
242282
return self.imgs[brain_id][(0, 0, *s)]
243283

244284
def read_precomputed_patch(self, brain_id, center):
245285
"""
246-
Reads an image patch from a precomputed array.
286+
Reads an image patch from a Precomputed array.
247287
248288
Parameters
249289
----------
250-
...
290+
brain_id : str
291+
Unique identifier of the sampled brain.
292+
center : Tuple[int]
293+
Center of image patch to be read.
294+
295+
Returns
296+
-------
297+
numpy.ndarray
298+
Image patch.
251299
"""
252300
s = img_util.get_slices(center, self.patch_shape)
253-
return self.segmentations[brain_id][(0, 0, *s)].read().result()
301+
return self.segmentations[brain_id][s].read().result()
254302

255303
def to_voxels(self, xyz_arr):
304+
"""
305+
Converts 3D points from physical to voxel coordinates.
306+
307+
Parameters
308+
----------
309+
xyz_arr : numpy.ndarray
310+
Array with shape (n, 3) that contains 3D points.
311+
312+
Returns
313+
-------
314+
numpy.ndarray
315+
3D Points converted to voxel coordinates.
316+
"""
256317
for i in range(3):
257318
xyz_arr[:, i] = xyz_arr[:, i] / self.anisotropy[i]
258319
return np.flip(xyz_arr, axis=1).astype(int)
@@ -442,7 +503,7 @@ def init_datasets(
442503
n_train_examples_per_epoch=100,
443504
n_validate_examples=0,
444505
segmentation_prefixes_path=None,
445-
sigma_bm4d=30,
506+
sigma_bm4d=16,
446507
swc_pointers=None
447508
):
448509
# Initializations

0 commit comments

Comments
 (0)