Skip to content

Commit 7f89607

Browse files
authored
Merge pull request #3 from AllenNeuralDynamics/refactor-update-train
refactor: updated doc
2 parents f51e2b2 + d5f115d commit 7f89607

File tree

2 files changed

+19
-15
lines changed

2 files changed

+19
-15
lines changed

src/aind_exaspim_image_compression/machine_learning/data_handling.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,21 +26,29 @@
2626

2727
# --- Custom Datasets ---
2828
class TrainDataset(Dataset):
29+
"""
30+
A PyTorch Dataset for sampling 3D patches from whole-brain images and
31+
applying the BM4D denoising algorithm. The dataset's __getitem__ method
32+
returns both the original and denoised patches. Optionally, the patch
33+
sampling maybe biased toward foreground regions by using the voxel
34+
coordinates from SWC files that represent neuron tracings.
35+
"""
2936

3037
def __init__(
3138
self,
3239
patch_shape,
3340
anisotropy=(0.748, 0.748, 1.0),
3441
boundary_buffer=4000,
35-
foreground_sampling_rate=0.25,
42+
foreground_sampling_rate=0.5,
43+
sigma=50,
3644
):
3745
# Call parent class
3846
super(TrainDataset, self).__init__()
3947

4048
# Class attributes
4149
self.anisotropy = anisotropy
4250
self.boundary_buffer = boundary_buffer
43-
self.denoise_bm4d = BM4D()
51+
self.denoise_bm4d = BM4D(sigma=sigma)
4452
self.foreground_sampling_rate = foreground_sampling_rate
4553
self.patch_shape = patch_shape
4654
self.swc_reader = Reader()
@@ -50,11 +58,11 @@ def __init__(
5058
self.imgs = dict()
5159

5260
# --- Ingest data ---
53-
def ingest_img(self, brain_id, img_path, swc_pointer):
54-
self.foreground[brain_id] = self.ingest_swcs(swc_pointer)
61+
def ingest_brain(self, brain_id, img_path, swc_pointer):
62+
self.foreground[brain_id] = self.load_swcs(swc_pointer)
5563
self.imgs[brain_id] = img_util.read(img_path)
5664

57-
def ingest_swcs(self, swc_pointer):
65+
def load_swcs(self, swc_pointer):
5866
if swc_pointer:
5967
# Initializations
6068
swc_dicts = self.swc_reader.read(swc_pointer)
@@ -146,7 +154,7 @@ def __len__(self):
146154
"""
147155
return len(self.ids)
148156

149-
def ingest_img(self, brain_id, img_path):
157+
def ingest_brain(self, brain_id, img_path):
150158
self.imgs[brain_id] = img_util.read(img_path)
151159

152160
def ingest_example(self, brain_id, voxel):
@@ -190,10 +198,6 @@ def __init__(self, dataset, batch_size=16):
190198
Dataset to iterated over.
191199
batch_size : int
192200
Number of examples in each batch.
193-
194-
Returns
195-
-------
196-
None
197201
"""
198202
# Instance attributes
199203
self.dataset = dataset
@@ -337,8 +341,8 @@ def init_datasets(
337341
swc_pointer = None
338342

339343
# Ingest data
340-
train_dataset.ingest_img(brain_id, img_path, swc_pointer)
341-
val_dataset.ingest_img(brain_id, img_path)
344+
train_dataset.ingest_brain(brain_id, img_path, swc_pointer)
345+
val_dataset.ingest_brain(brain_id, img_path)
342346

343347
# Generate validation examples
344348
for _ in range(n_validate_examples):

src/aind_exaspim_image_compression/utils/img_util.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,12 +191,12 @@ def get_patch(img, voxel, shape, is_center=True):
191191
numpy.ndarray
192192
Patch extracted from the given image.
193193
"""
194-
# Get image patch coordiantes
194+
# Get patch coordinates
195195
start, end = get_start_end(voxel, shape, is_center=is_center)
196196
valid_start = any([s >= 0 for s in start])
197197
valid_end = any([e < img.shape[i + 2] for i, e in enumerate(end)])
198198

199-
# Get image patch
199+
# Read patch
200200
if valid_start and valid_end:
201201
return img[
202202
0, 0, start[0]: end[0], start[1]: end[1], start[2]: end[2]
@@ -383,7 +383,7 @@ def compute_cratio(img, codec, patch_shape=(64, 64, 64)):
383383
Parameters
384384
----------
385385
img : np.ndarray
386-
Image to compute compression ratio of.
386+
Image to compute the compression ratio of.
387387
codec : blosc.Blosc
388388
Blosc codec used to compress each chunk.
389389
patch_shape : Tuple[int]

0 commit comments

Comments
 (0)