Skip to content

Commit 0e4e840

Browse files
author
Anna Grim
committed
refactor: simplified repo
1 parent 80e987a commit 0e4e840

File tree

3 files changed

+63
-126
lines changed

3 files changed

+63
-126
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ dependencies = [
3232
'torchvision',
3333
'tqdm',
3434
'xarray_multiscale',
35-
'zarr'
35+
'zarr',
36+
"aind-exaspim-image-utils @ git+https://github.com/AllenNeuralDynamics/aind-exaspim-dataset-utils.git@main"
3637
]
3738

3839
[project.optional-dependencies]

src/aind_exaspim_image_compression/inference.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,13 @@
44
@author: Anna Grim
55
@email: anna.grim@alleninstitute.org
66
7-
Denoising routines for 3D microscopy images using patch-based deep learning
8-
inference. Includes functions to extract overlapping patches, normalize and
9-
batch process them through a model on GPU, and stitch denoised patches back
10-
into a full 3D volume.
7+
Code for using BM4D-Net to denoise 3D microscopy images. Includes routines to
8+
extract overlapping patches, normalize and batch process them through a model
9+
on GPU, and stitch denoised patches back into a full 3D volume.
1110
1211
"""
1312

14-
from concurrent.futures import (
15-
ThreadPoolExecutor,
16-
as_completed,
17-
)
18-
from numcodecs import blosc
13+
from concurrent.futures import ThreadPoolExecutor, as_completed
1914
from tqdm import tqdm
2015

2116
import itertools
@@ -37,8 +32,8 @@ def predict(
3732
verbose=True
3833
):
3934
"""
40-
Denoises a 3D image by processing patches in batches and running deep
41-
learning model.
35+
Denoises a 3D image by tiling it into overlapping patches, forming batches
36+
of patches, and processing each batch through the given model.
4237
4338
Parameters
4439
----------
@@ -52,8 +47,11 @@ def predict(
5247
Size of the cubic patch extracted from the image. Default is 64.
5348
overlap : int, optional
5449
Number of voxels to overlap between patches. Default is 16.
50+
trim : int, optional
51+
Number of voxels from the image boundary that are set to zero to
52+
suppress noisy edge predictions. Default is 5.
5553
verbose : bool, optional
56-
Whether to show a tqdm progress bar. Default is True.
54+
Whether to show a progress bar. Default is True.
5755
5856
Returns
5957
-------
@@ -65,16 +63,16 @@ def predict(
6563
img = img[np.newaxis, ...]
6664

6765
# Initializations
68-
starts_generator = generate_patch_starts(img, patch_size, overlap)
66+
patch_starts_generator = generate_patch_starts(img, patch_size, overlap)
6967
n_starts = count_patches(img, patch_size, overlap)
7068
if denoised is None:
7169
denoised = np.zeros_like(img, dtype=np.uint16)
7270

7371
# Main
7472
pbar = tqdm(total=n_starts, desc="Denoise") if verbose else None
7573
for i in range(0, n_starts, batch_size):
76-
# Run model
77-
starts = list(itertools.islice(starts_generator, batch_size))
74+
# Extract batch and run model
75+
starts = list(itertools.islice(patch_starts_generator, batch_size))
7876
patches = _predict_batch(img, model, starts, patch_size, trim)
7977

8078
# Store result
@@ -166,6 +164,7 @@ def read_patch(i):
166164
with torch.no_grad():
167165
outputs = model(inputs).cpu().squeeze(1).numpy()
168166

167+
# Process results
169168
N = outputs.shape[0]
170169
start, end = trim, patch_size - trim
171170
final_shape = (end - start,) * 3

src/aind_exaspim_image_compression/machine_learning/data_handling.py

Lines changed: 47 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Created on Thu Dec 5 14:00:00 2024
2+
Created on Jan 3 12:30:00 2025
33
44
@author: Anna Grim
55
@email: anna.grim@alleninstitute.org
@@ -9,7 +9,7 @@
99
"""
1010

1111
from abc import ABC, abstractmethod
12-
from careamics.transforms.n2v_manipulate import N2VManipulate
12+
from aind_exaspim_dataset_utils.s3_util import get_img_prefix
1313
from concurrent.futures import (
1414
ProcessPoolExecutor,
1515
ThreadPoolExecutor,
@@ -24,15 +24,16 @@
2424
import torch
2525

2626
from aind_exaspim_image_compression.utils import img_util, util
27+
from aind_exaspim_image_compression.utils.img_util import BM4D
2728
from aind_exaspim_image_compression.utils.swc_util import Reader
2829

2930

3031
# --- Custom Datasets ---
3132
class TrainDataset(Dataset):
33+
3234
def __init__(
3335
self,
3436
patch_shape,
35-
transform,
3637
anisotropy=(0.748, 0.748, 1.0),
3738
boundary_buffer=4000,
3839
foreground_sampling_rate=0.5,
@@ -43,15 +44,19 @@ def __init__(
4344
# Class attributes
4445
self.anisotropy = anisotropy
4546
self.boundary_buffer = boundary_buffer
47+
self.denoise_bm4d = BM4D()
4648
self.foreground_sampling_rate = foreground_sampling_rate
4749
self.patch_shape = patch_shape
4850
self.swc_reader = Reader()
49-
self.transform = transform
51+
52+
# Ground truth denoising
53+
5054

5155
# Data structures
5256
self.foreground = dict()
5357
self.imgs = dict()
5458

59+
# --- Ingest data ---
5560
def ingest_img(self, brain_id, img_path, swc_pointer):
5661
self.foreground[brain_id] = self.ingest_swcs(swc_pointer)
5762
self.imgs[brain_id] = img_util.read(img_path)
@@ -73,25 +78,11 @@ def ingest_swcs(self, swc_pointer):
7378
return foreground
7479
return set()
7580

76-
def __len__(self):
77-
"""
78-
Counts the number of whole-brain images in the dataset.
79-
80-
Parameters
81-
----------
82-
None
83-
84-
Returns
85-
-------
86-
int
87-
Number of whole-brain images in the dataset.
88-
"""
89-
return len(self.imgs)
90-
81+
# --- Core Routines ---
9182
def __getitem__(self, dummy_input):
9283
brain_id = self.sample_brain()
9384
voxel = self.sample_voxel(brain_id)
94-
return self.transform(self.get_patch(brain_id, voxel))
85+
return self.denoise_bm4d(self.get_patch(brain_id, voxel))
9586

9687
def sample_brain(self):
9788
return util.sample_once(self.imgs.keys())
@@ -110,6 +101,21 @@ def sample_voxel(self, brain_id):
110101
return tuple(voxel)
111102

112103
# --- Helpers ---
104+
def __len__(self):
105+
"""
106+
Counts the number of whole-brain images in the dataset.
107+
108+
Parameters
109+
----------
110+
None
111+
112+
Returns
113+
-------
114+
int
115+
Number of whole-brain images in the dataset.
116+
"""
117+
return len(self.imgs)
118+
113119
def get_patch(self, brain_id, voxel):
114120
s, e = img_util.get_start_end(voxel, self.patch_shape)
115121
return self.imgs[brain_id][0, 0, s[0]: e[0], s[1]: e[1], s[2]: e[2]]
@@ -124,13 +130,14 @@ def update_foreground_sampling_rate(self, foreground_sampling_rate):
124130

125131

126132
class ValidateDataset(Dataset):
127-
def __init__(self, patch_shape, transform):
133+
134+
def __init__(self, patch_shape):
128135
# Call parent class
129136
super(ValidateDataset, self).__init__()
130137

131138
# Instance attributes
132139
self.patch_shape = patch_shape
133-
self.transform = transform
140+
self.denoise_bm4d = BM4D()
134141

135142
# Data structures
136143
self.ids = list()
@@ -159,7 +166,7 @@ def ingest_img(self, brain_id, img_path):
159166

160167
def ingest_example(self, brain_id, voxel):
161168
# Get clean image
162-
noise, denoised, mn_mx = self.transform(
169+
noise, denoised, mn_mx = self.denoise_bm4d(
163170
self.get_patch(brain_id, voxel)
164171
)
165172

@@ -203,6 +210,7 @@ def __init__(self, dataset, batch_size=16):
203210
-------
204211
None
205212
"""
213+
# Instance attributes
206214
self.dataset = dataset
207215
self.batch_size = batch_size
208216
self.patch_shape = dataset.patch_shape
@@ -232,43 +240,7 @@ def _load_batch(self, idx):
232240
pass
233241

234242

235-
class TrainN2VDataLoader(DataLoader):
236-
"""
237-
DataLoader that uses multithreading to fetch image patches from the cloud
238-
to form batches to train Noise2Void (N2V).
239-
"""
240-
241-
def __init__(self, dataset, batch_size=16, n_upds=100):
242-
# Call parent class
243-
super().__init__(dataset, batch_size)
244-
245-
# Instance attributes
246-
self.n_upds = n_upds
247-
248-
def _get_iterator(self):
249-
return range(self.n_upds)
250-
251-
def _load_batch(self, dummy_input):
252-
with ThreadPoolExecutor() as executor:
253-
# Assign threads
254-
threads = list()
255-
for _ in range(self.batch_size):
256-
threads.append(executor.submit(self.dataset.__getitem__, -1))
257-
258-
# Process results
259-
shape = (self.batch_size, 1,) + self.patch_shape
260-
masked_patches = np.zeros(shape)
261-
patches = np.zeros(shape)
262-
masks = np.zeros(shape)
263-
for i, thread in enumerate(as_completed(threads)):
264-
masked_patch, patch, mask = thread.result()
265-
masked_patches[i, 0, ...] = masked_patch
266-
patches[i, 0, ...] = patch
267-
masks[i, 0, ...] = mask
268-
return to_tensor(masked_patches), to_tensor(patches), to_tensor(masks)
269-
270-
271-
class TrainBM4DDataLoader(DataLoader):
243+
class TrainDataLoader(DataLoader):
272244
"""
273245
DataLoader that uses multithreading to fetch image patches from the cloud
274246
to form batches.
@@ -282,8 +254,11 @@ def __init__(self, dataset, batch_size=8, n_upds=20):
282254
----------
283255
dataset : Dataset.ProposalDataset
284256
Instance of custom dataset.
285-
batch_size : int
286-
Number of samples per batch.
257+
batch_size : int, optional
258+
Number of samples per batch. Default is 8.
259+
n_upds : int, optional
260+
Number of backpropagation gradient updates before validating the
261+
model. Default is 20.
287262
288263
Returns
289264
-------
@@ -316,45 +291,7 @@ def _load_batch(self, dummy_input):
316291
return to_tensor(noise_patches), to_tensor(clean_patches), None
317292

318293

319-
class ValidateN2VDataLoader(DataLoader):
320-
"""
321-
DataLoader that uses multithreading to fetch image patches from the cloud
322-
to form batches.
323-
"""
324-
325-
def __init__(self, dataset, batch_size=8):
326-
super().__init__(dataset, batch_size)
327-
328-
def _get_iterator(self):
329-
return range(0, len(self.dataset), self.batch_size)
330-
331-
def _load_batch(self, start_idx):
332-
# Compute batch size
333-
n_remaining_examples = len(self.dataset) - start_idx
334-
batch_size = min(self.batch_size, n_remaining_examples)
335-
336-
# Generate batch
337-
with ThreadPoolExecutor() as executor:
338-
# Assign threads
339-
threads = list()
340-
for idx_shift in range(batch_size):
341-
idx = start_idx + idx_shift
342-
threads.append(executor.submit(self.dataset.__getitem__, idx))
343-
344-
# Process results
345-
shape = (batch_size, 1,) + self.patch_shape
346-
masked_patches = np.zeros(shape)
347-
patches = np.zeros(shape)
348-
masks = np.zeros(shape)
349-
for i, thread in enumerate(as_completed(threads)):
350-
masked_patch, patch, mask = thread.result()
351-
masked_patches[i, 0, ...] = masked_patch
352-
patches[i, 0, ...] = patch
353-
masks[i, 0, ...] = mask
354-
return to_tensor(masked_patches), to_tensor(patches), to_tensor(masks)
355-
356-
357-
class ValidateBM4DDataLoader(DataLoader):
294+
class ValidateDataLoader(DataLoader):
358295
"""
359296
DataLoader that uses multiprocessing to fetch image patches from the cloud
360297
to form batches.
@@ -399,30 +336,30 @@ def init_datasets(
399336
brain_ids,
400337
img_paths_json,
401338
patch_shape,
402-
n_validate_examples,
403339
foreground_sampling_rate=0.5,
404-
method="bm4d",
340+
n_validate_examples=0,
405341
swc_dict=None
406342
):
407343
# Initializations
408-
transform = N2VManipulate() if method == "n2v" else img_util.BM4D()
409344
train_dataset = TrainDataset(
410-
patch_shape,
411-
transform,
412-
foreground_sampling_rate=foreground_sampling_rate,
345+
patch_shape, foreground_sampling_rate=foreground_sampling_rate,
413346
)
414-
val_dataset = ValidateDataset(patch_shape, transform)
347+
val_dataset = ValidateDataset(patch_shape)
415348

416349
# Load data
417350
for brain_id in tqdm(brain_ids, desc="Load Data"):
418-
img_path = img_util.get_img_prefix(brain_id, img_paths_json)
351+
# Set image path
352+
img_path = get_img_prefix(brain_id, img_paths_json)
419353
img_path += str(0)
354+
355+
# Set SWC path
420356
if swc_dict:
421357
swc_pointer = deepcopy(swc_dict)
422358
swc_pointer["path"] += f"/{brain_id}/world"
423359
else:
424360
swc_pointer = None
425361

362+
# Ingest data
426363
train_dataset.ingest_img(brain_id, img_path, swc_pointer)
427364
val_dataset.ingest_img(brain_id, img_path)
428365

@@ -436,7 +373,7 @@ def init_datasets(
436373

437374
def to_tensor(arr):
438375
"""
439-
Converts a numpy array to a torch tensor.
376+
Converts the given numpy array to a torch tensor.
440377
441378
Parameters
442379
----------

0 commit comments

Comments
 (0)