Skip to content

Commit 7018671

Browse files
authored
Merge pull request #8 from AllenNeuralDynamics/refactor-training
feat: brightness biased sampling
2 parents 420779d + 690a609 commit 7018671

File tree

4 files changed

+216
-56
lines changed

4 files changed

+216
-56
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ dependencies = [
2929
'scikit-learn',
3030
'scipy',
3131
'tensorboard',
32+
'tensorstore',
3233
'torch',
3334
'torchvision',
3435
'tqdm',

src/aind_exaspim_image_compression/inference.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def predict_patch(patch, model, normalization_percentiles=[5, 99.9]):
123123
Denoised 3D patch with the same shape as input patch.
124124
"""
125125
# Run model
126-
assert len(normalization_percentiles) == 2, "Must provide two percentiles"
126+
assert len(normalization_percentiles) == 2, "Must provide two percentiles"
127127
mn, mx = np.percentile(patch, normalization_percentiles)
128128
patch = to_tensor((patch - mn) / max(mx, 1))
129129
with torch.no_grad():
@@ -134,13 +134,20 @@ def predict_patch(patch, model, normalization_percentiles=[5, 99.9]):
134134
return np.abs(pred[0, 0, ...] * mx + mn).astype(np.uint16)
135135

136136

137-
def _predict_batch(img, model, starts, patch_size, trim=5):
137+
def _predict_batch(
138+
img,
139+
model,
140+
starts,
141+
patch_size,
142+
normalization_percentiles=[5, 99.9],
143+
trim=5,
144+
):
138145
# Subroutine
139146
def read_patch(i):
140147
start = starts[i]
141148
end = [min(s + patch_size, d) for s, d in zip(start, (D, H, W))]
142149
patch = img[0, 0, start[0]:end[0], start[1]:end[1], start[2]:end[2]]
143-
mn, mx = np.percentile(patch, [5, 99.9])
150+
mn, mx = np.percentile(patch, normalization_percentiles)
144151
patch = add_padding((patch - mn) / max(mx, 1), patch_size)
145152
return i, patch.astype(np.float32), (mn, mx)
146153

src/aind_exaspim_image_compression/machine_learning/data_handling.py

Lines changed: 174 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import numpy as np
1919
import random
20+
import tensorstore as ts
2021
import torch
2122

2223
from aind_exaspim_image_compression.utils import img_util, util
@@ -37,10 +38,12 @@ def __init__(
3738
self,
3839
patch_shape,
3940
anisotropy=(0.748, 0.748, 1.0),
40-
boundary_buffer=4000,
41+
boundary_buffer=5000,
4142
foreground_sampling_rate=0.2,
42-
n_examples_per_epoch=200,
43-
sigma_bm4d=50,
43+
min_brightness=200,
44+
n_examples_per_epoch=300,
45+
normalization_percentiles=[0.5, 99.9],
46+
sigma_bm4d=30,
4447
):
4548
# Call parent class
4649
super(TrainDataset, self).__init__()
@@ -49,51 +52,110 @@ def __init__(
4952
self.anisotropy = anisotropy
5053
self.boundary_buffer = boundary_buffer
5154
self.foreground_sampling_rate = foreground_sampling_rate
55+
self.min_brightness = min_brightness
5256
self.n_examples_per_epoch = n_examples_per_epoch
57+
self.normalization_percentiles = normalization_percentiles
5358
self.patch_shape = patch_shape
5459
self.sigma_bm4d = sigma_bm4d
5560
self.swc_reader = Reader()
5661

5762
# Data structures
58-
self.foreground = dict()
63+
self.segmentations = dict()
64+
self.skeletons = dict()
5965
self.imgs = dict()
6066

6167
# --- Ingest data ---
62-
def ingest_brain(self, brain_id, img_path, swc_pointer):
63-
self.foreground[brain_id] = self.load_swcs(swc_pointer)
68+
def ingest_brain(self, brain_id, img_path, segmentation_path, swc_pointer):
69+
"""
70+
Loads a brain image, label mask, and skeletons, then stores each in
71+
internal dictionaries.
72+
73+
Parameters
74+
----------
75+
brain_id : hashable
76+
Unique identifier for the brain corresponding to the image.
77+
img_path : str or Path
78+
Path to whole-brain image to be read.
79+
segmentation_path : str
80+
Path to segmentation.
81+
swc_pointer : str
82+
Path to SWC files.
83+
"""
84+
self.segmentations[brain_id] = self.load_segmentation(segmentation_path)
6485
self.imgs[brain_id] = img_util.read(img_path)
86+
self.skeletons[brain_id] = self.load_swcs(swc_pointer)
87+
88+
def load_segmentation(self, segmentation_path):
89+
"""
90+
Reads a segmentation mask generated by Google Applied Sciences (GAS).
91+
92+
Parameters
93+
----------
94+
segmentation_path : str
95+
Path to segmentation.
96+
97+
Returns
98+
-------
99+
tensorstore.TensorStore or None
    Lazily-readable segmentation volume with axes permuted to match the
    raw image, or None if no segmentation path was given.
100+
"""
101+
if segmentation_path:
102+
# Load image
103+
label_mask = ts.open(
104+
{
105+
"driver": "neuroglancer_precomputed",
106+
"kvstore": {
107+
"driver": "gcs",
108+
"bucket": "allen-nd-goog",
109+
"path": segmentation_path,
110+
},
111+
"context": {
112+
"cache_pool": {"total_bytes_limit": 1000000000},
113+
"cache_pool#remote": {"total_bytes_limit": 1000000000},
114+
"data_copy_concurrency": {"limit": 8},
115+
},
116+
"recheck_cached_data": "open",
117+
}
118+
).result()
119+
120+
# Permute axes to be consistent with raw image.
121+
label_mask = label_mask[ts.d["channel"][0]]
122+
label_mask = label_mask[ts.d[0].transpose[2]]
123+
label_mask = label_mask[ts.d[0].transpose[1]]
124+
return label_mask
125+
else:
126+
return None
65127

66128
def load_swcs(self, swc_pointer):
67129
if swc_pointer:
68130
# Initializations
69131
swc_dicts = self.swc_reader.read(swc_pointer)
70132
n_points = np.sum([len(d["xyz"]) for d in swc_dicts])
71133

72-
# Extract foreground voxels
134+
# Extract skeleton voxels
73135
if n_points > 0:
74136
start = 0
75-
foreground = np.zeros((n_points, 3), dtype=np.int32)
137+
skeletons = np.zeros((n_points, 3), dtype=np.int32)
76138
for swc_dict in swc_dicts:
77139
end = start + len(swc_dict["xyz"])
78-
foreground[start:end] = self.to_voxels(swc_dict["xyz"])
140+
skeletons[start:end] = self.to_voxels(swc_dict["xyz"])
79141
start = end
80-
return foreground
81-
return set()
142+
return skeletons
143+
return None
82144

83145
# --- Core Routines ---
84146
def __getitem__(self, dummy_input):
85147
# Sample image patch
86148
brain_id = self.sample_brain()
87149
voxel = self.sample_voxel(brain_id)
88-
noise = self.get_patch(brain_id, voxel)
89-
mn, mx = np.percentile(noise, 5), np.percentile(noise, 99.9)
150+
noise = self.read_patch(brain_id, voxel)
151+
mn, mx = np.percentile(noise, self.normalization_percentiles)
90152

91153
# Denoise image patch
92154
denoised = bm4d(noise, self.sigma_bm4d)
93155

94156
# Normalize image patches
95-
noise = (noise - mn) / max(mx, 1)
96-
denoised = (denoised - mn) / max(mx, 1)
157+
noise = (noise - mn) / max(mx - mn, 1)
158+
denoised = (denoised - mn) / max(mx - mn, 1)
97159
return noise, denoised, (mn, mx)
98160

99161
def sample_brain(self):
@@ -114,12 +176,47 @@ def sample_voxel(self, brain_id):
114176
return self.sample_interior_voxel(brain_id)
115177

116178
def sample_foreground_voxel(self, brain_id):
117-
if len(self.foreground[brain_id]) > 0:
118-
idx = random.randint(0, len(self.foreground[brain_id]) - 1)
119-
shift = np.random.randint(0, 16, size=3)
120-
return tuple(self.foreground[brain_id][idx] + shift)
179+
if self.skeletons[brain_id] is not None and np.random.random() > 0.5:
180+
return self.sample_skeleton_voxel(brain_id)
181+
#elif self.segmentations[brain_id] is not None:
182+
# return self.sample_segmentation_voxel(brain_id)
121183
else:
122-
return self.sample_interior_voxel(brain_id)
184+
return self.sample_bright_voxel(brain_id)
185+
186+
def sample_skeleton_voxel(self, brain_id):
187+
idx = random.randint(0, len(self.skeletons[brain_id]) - 1)
188+
shift = np.random.randint(0, 16, size=3)
189+
return tuple(self.skeletons[brain_id][idx] + shift)
190+
191+
def sample_segmentation_voxel(self, brain_id):
192+
cnt = 0
193+
while cnt < 32:
194+
# Read random image patch
195+
voxel = self.sample_interior_voxel(brain_id)
196+
labels_patch = self.read_precomputed_patch(brain_id, voxel)
197+
198+
# Check if labels patch has large enough object
199+
# --> call fastremap
200+
# --> find largest object
201+
return voxel
202+
203+
def sample_bright_voxel(self, brain_id):
204+
cnt = 0
205+
brightest_voxel, max_brightness = None, 0
206+
while cnt < 32:
207+
# Read random image patch
208+
voxel = self.sample_interior_voxel(brain_id)
209+
img_patch = self.read_patch(brain_id, voxel)
210+
211+
# Check if image patch is bright enough
212+
brightness = np.max(img_patch)
213+
if brightness >= self.min_brightness:
214+
return voxel
215+
elif brightness > max_brightness:
216+
brightest_voxel = voxel
217+
max_brightness = brightness
218+
cnt += 1
219+
return brightest_voxel
123220

124221
def sample_interior_voxel(self, brain_id):
125222
voxel = list()
@@ -140,36 +237,54 @@ def __len__(self):
140237
"""
141238
return self.n_examples_per_epoch
142239

143-
def get_patch(self, brain_id, voxel):
144-
s, e = img_util.get_start_end(voxel, self.patch_shape)
145-
return self.imgs[brain_id][0, 0, s[0]: e[0], s[1]: e[1], s[2]: e[2]]
240+
def read_patch(self, brain_id, center):
241+
s = img_util.get_slices(center, self.patch_shape)
242+
return self.imgs[brain_id][(0, 0, *s)]
243+
244+
def read_precomputed_patch(self, brain_id, center):
245+
"""
246+
Reads an image patch from a precomputed array.
247+
248+
Parameters
249+
----------
250+
...
251+
"""
252+
s = img_util.get_slices(center, self.patch_shape)
253+
return self.segmentations[brain_id][(0, 0, *s)].read().result()
146254

147255
def to_voxels(self, xyz_arr):
148256
for i in range(3):
149257
xyz_arr[:, i] = xyz_arr[:, i] / self.anisotropy[i]
150258
return np.flip(xyz_arr, axis=1).astype(int)
151259

152-
def update_foreground_sampling_rate(self, foreground_sampling_rate):
153-
self.foreground_sampling_rate = foreground_sampling_rate
154-
155260

156261
class ValidateDataset(Dataset):
157262

158-
def __init__(self, patch_shape, sigma_bm4d=50):
263+
def __init__(
264+
self,
265+
patch_shape,
266+
normalization_percentiles=[0.5, 99.9],
267+
sigma_bm4d=30,
268+
):
159269
"""
160270
Instantiates a ValidateDataset object.
161271
162272
Parameters
163273
----------
164274
patch_shape : Tuple[int]
165275
Shape of image patches to be extracted.
166-
sigma_bm4d : float
167-
Smoothing parameter used in the BM4D denoising algorithm.
276+
normalization_percentiles : List[float], optional
277+
Upper and lower percentiles used to normalize the input image.
278+
Default is [0.5, 99.9].
279+
sigma_bm4d : float, optional
280+
Smoothing parameter used in the BM4D denoising algorithm. Default
281+
is 30.
168282
"""
169283
# Call parent class
170284
super(ValidateDataset, self).__init__()
171285

172286
# Instance attributes
287+
self.normalization_percentiles = normalization_percentiles
173288
self.patch_shape = patch_shape
174289
self.sigma_bm4d = sigma_bm4d
175290

@@ -217,13 +332,13 @@ def ingest_example(self, brain_id, voxel):
217332
Voxel coordinates of the patch center in the brain volume.
218333
"""
219334
# Get image patches
220-
noise = self.get_patch(brain_id, voxel)
221-
mn, mx = np.percentile(noise, 5), np.percentile(noise, 99.9)
335+
noise = self.read_patch(brain_id, voxel)
336+
mn, mx = np.percentile(noise, self.normalization_percentiles)
222337
denoised = bm4d(noise, self.sigma_bm4d)
223338

224339
# Normalize image patches
225-
noise = (noise - mn) / max(mx, 1)
226-
denoised = (denoised - mn) / max(mx, 1)
340+
noise = (noise - mn) / max(mx - mn, 1)
341+
denoised = (denoised - mn) / max(mx - mn, 1)
227342

228343
# Store results
229344
self.example_ids.append((brain_id, voxel))
@@ -251,9 +366,9 @@ def __getitem__(self, idx):
251366
"""
252367
return self.noise[idx], self.denoised[idx], self.mn_mxs[idx]
253368

254-
def get_patch(self, brain_id, voxel):
255-
s, e = img_util.get_start_end(voxel, self.patch_shape)
256-
return self.imgs[brain_id][0, 0, s[0]: e[0], s[1]: e[1], s[2]: e[2]]
369+
def read_patch(self, brain_id, center):
370+
slices = img_util.get_slices(center, self.patch_shape)
371+
return self.imgs[brain_id][(0, 0, *slices)]
257372

258373

259374
# --- Custom Dataloader ---
@@ -326,8 +441,9 @@ def init_datasets(
326441
foreground_sampling_rate=0.5,
327442
n_train_examples_per_epoch=100,
328443
n_validate_examples=0,
329-
sigma_bm4d=50,
330-
swc_dict=None
444+
segmentation_prefixes_path=None,
445+
sigma_bm4d=30,
446+
swc_pointers=None
331447
):
332448
# Initializations
333449
train_dataset = TrainDataset(
@@ -338,19 +454,35 @@ def init_datasets(
338454
)
339455
val_dataset = ValidateDataset(patch_shape)
340456

457+
# Read segmentation path lookup (if applicable)
458+
if segmentation_prefixes_path:
459+
segmentation_paths = util.read_json(segmentation_prefixes_path)
460+
else:
461+
segmentation_paths = dict()
462+
341463
# Load data
342464
for brain_id in tqdm(brain_ids, desc="Load Data"):
343-
# Set paths
465+
# Set image path
344466
img_path = get_img_prefix(brain_id, img_paths_json) + str(0)
345-
if swc_dict:
346-
swc_pointer = deepcopy(swc_dict)
467+
468+
# Set segmentation path
469+
if brain_id in segmentation_paths:
470+
segmentation_path = segmentation_paths[brain_id]
471+
else:
472+
segmentation_path = None
473+
474+
# Set SWC pointer
475+
if swc_pointers:
476+
swc_pointer = deepcopy(swc_pointers)
347477
swc_pointer["path"] += f"/{brain_id}/world"
348478
else:
349479
swc_pointer = None
350480

351481
# Ingest data
352-
train_dataset.ingest_brain(brain_id, img_path, swc_pointer)
353482
val_dataset.ingest_brain(brain_id, img_path)
483+
train_dataset.ingest_brain(
484+
brain_id, img_path, segmentation_path, swc_pointer
485+
)
354486

355487
# Generate validation examples
356488
for _ in range(n_validate_examples):

0 commit comments

Comments
 (0)