11"""
2- Created on Thu Dec 5 14:00 :00 2024
2+ Created on Jan 3 12:30 :00 2025
33
44@author: Anna Grim
55@email: anna.grim@alleninstitute.org
99"""
1010
1111from abc import ABC , abstractmethod
12- from careamics . transforms . n2v_manipulate import N2VManipulate
12+ from aind_exaspim_dataset_utils . s3_util import get_img_prefix
1313from concurrent .futures import (
1414 ProcessPoolExecutor ,
1515 ThreadPoolExecutor ,
2424import torch
2525
2626from aind_exaspim_image_compression .utils import img_util , util
27+ from aind_exaspim_image_compression .utils .img_util import BM4D
2728from aind_exaspim_image_compression .utils .swc_util import Reader
2829
2930
3031# --- Custom Datasets ---
3132class TrainDataset (Dataset ):
33+
3234 def __init__ (
3335 self ,
3436 patch_shape ,
35- transform ,
3637 anisotropy = (0.748 , 0.748 , 1.0 ),
3738 boundary_buffer = 4000 ,
3839 foreground_sampling_rate = 0.5 ,
@@ -43,15 +44,19 @@ def __init__(
4344 # Class attributes
4445 self .anisotropy = anisotropy
4546 self .boundary_buffer = boundary_buffer
47+ self .denoise_bm4d = BM4D ()
4648 self .foreground_sampling_rate = foreground_sampling_rate
4749 self .patch_shape = patch_shape
4850 self .swc_reader = Reader ()
49- self .transform = transform
51+
52+ # Ground truth denoising
53+
5054
5155 # Data structures
5256 self .foreground = dict ()
5357 self .imgs = dict ()
5458
59+ # --- Ingest data ---
def ingest_img(self, brain_id, img_path, swc_pointer):
    """
    Registers one whole-brain image with the dataset.

    Reads the image at "img_path" via img_util.read and stores it under
    "brain_id", after extracting the brain's foreground voxels from the
    given SWC pointer.

    Parameters
    ----------
    brain_id : hashable
        Identifier used as the key into self.imgs and self.foreground.
    img_path : str
        Path passed to img_util.read; presumably a cloud/zarr prefix —
        confirm against img_util.
    swc_pointer : object or None
        Forwarded to self.ingest_swcs to build the foreground voxel set.

    Returns
    -------
    None
    """
    foreground_voxels = self.ingest_swcs(swc_pointer)
    self.foreground[brain_id] = foreground_voxels
    self.imgs[brain_id] = img_util.read(img_path)
@@ -73,25 +78,11 @@ def ingest_swcs(self, swc_pointer):
7378 return foreground
7479 return set ()
7580
76- def __len__ (self ):
77- """
78- Counts the number of whole-brain images in the dataset.
79-
80- Parameters
81- ----------
82- None
83-
84- Returns
85- -------
86- int
87- Number of whole-brain images in the dataset.
88- """
89- return len (self .imgs )
90-
81+ # --- Core Routines ---
def __getitem__(self, dummy_input):
    """
    Generates one training example by random sampling.

    The index argument is ignored ("dummy_input"); instead a brain is
    sampled, then a voxel within it, and the extracted patch is passed
    through the BM4D denoiser.

    Parameters
    ----------
    dummy_input : Any
        Unused; present only to satisfy the Dataset indexing protocol.

    Returns
    -------
    object
        Output of self.denoise_bm4d on the sampled patch — NOTE(review):
        in ValidateDataset the same callable yields a (noise, denoised,
        mn_mx) tuple; confirm the expected structure here.
    """
    sampled_brain = self.sample_brain()
    sampled_voxel = self.sample_voxel(sampled_brain)
    patch = self.get_patch(sampled_brain, sampled_voxel)
    return self.denoise_bm4d(patch)
9586
def sample_brain(self):
    """
    Picks one brain ID from the ingested images.

    Returns
    -------
    hashable
        A key of self.imgs chosen by util.sample_once — presumably a
        uniform draw; confirm against util.sample_once.
    """
    candidate_ids = self.imgs.keys()
    return util.sample_once(candidate_ids)
@@ -110,6 +101,21 @@ def sample_voxel(self, brain_id):
110101 return tuple (voxel )
111102
112103 # --- Helpers ---
def __len__(self):
    """
    Counts the whole-brain images currently ingested.

    Returns
    -------
    int
        Number of entries in self.imgs, i.e. the number of whole-brain
        images in the dataset.
    """
    n_brains = len(self.imgs)
    return n_brains
118+
def get_patch(self, brain_id, voxel):
    """
    Extracts an image patch anchored at the given voxel.

    Start/end corners are computed by img_util.get_start_end from the
    voxel and self.patch_shape. The leading [0, 0] index assumes the
    stored image is 5D (e.g. a TCZYX-style zarr array) — TODO confirm.

    Parameters
    ----------
    brain_id : hashable
        Key into self.imgs identifying the whole-brain image.
    voxel : tuple
        3D voxel coordinate anchoring the patch.

    Returns
    -------
    array-like
        Patch of shape self.patch_shape sliced from the image.
    """
    start, end = img_util.get_start_end(voxel, self.patch_shape)
    img = self.imgs[brain_id]
    return img[0, 0, start[0]:end[0], start[1]:end[1], start[2]:end[2]]
@@ -124,13 +130,14 @@ def update_foreground_sampling_rate(self, foreground_sampling_rate):
124130
125131
126132class ValidateDataset (Dataset ):
127- def __init__ (self , patch_shape , transform ):
133+
134+ def __init__ (self , patch_shape ):
128135 # Call parent class
129136 super (ValidateDataset , self ).__init__ ()
130137
131138 # Instance attributes
132139 self .patch_shape = patch_shape
133- self .transform = transform
140+ self .denoise_bm4d = BM4D ()
134141
135142 # Data structures
136143 self .ids = list ()
@@ -159,7 +166,7 @@ def ingest_img(self, brain_id, img_path):
159166
160167 def ingest_example (self , brain_id , voxel ):
161168 # Get clean image
162- noise , denoised , mn_mx = self .transform (
169+ noise , denoised , mn_mx = self .denoise_bm4d (
163170 self .get_patch (brain_id , voxel )
164171 )
165172
@@ -203,6 +210,7 @@ def __init__(self, dataset, batch_size=16):
203210 -------
204211 None
205212 """
213+ # Instance attributes
206214 self .dataset = dataset
207215 self .batch_size = batch_size
208216 self .patch_shape = dataset .patch_shape
@@ -232,43 +240,7 @@ def _load_batch(self, idx):
232240 pass
233241
234242
235- class TrainN2VDataLoader (DataLoader ):
236- """
237- DataLoader that uses multithreading to fetch image patches from the cloud
238- to form batches to train Noise2Void (N2V).
239- """
240-
241- def __init__ (self , dataset , batch_size = 16 , n_upds = 100 ):
242- # Call parent class
243- super ().__init__ (dataset , batch_size )
244-
245- # Instance attributes
246- self .n_upds = n_upds
247-
248- def _get_iterator (self ):
249- return range (self .n_upds )
250-
251- def _load_batch (self , dummy_input ):
252- with ThreadPoolExecutor () as executor :
253- # Assign threads
254- threads = list ()
255- for _ in range (self .batch_size ):
256- threads .append (executor .submit (self .dataset .__getitem__ , - 1 ))
257-
258- # Process results
259- shape = (self .batch_size , 1 ,) + self .patch_shape
260- masked_patches = np .zeros (shape )
261- patches = np .zeros (shape )
262- masks = np .zeros (shape )
263- for i , thread in enumerate (as_completed (threads )):
264- masked_patch , patch , mask = thread .result ()
265- masked_patches [i , 0 , ...] = masked_patch
266- patches [i , 0 , ...] = patch
267- masks [i , 0 , ...] = mask
268- return to_tensor (masked_patches ), to_tensor (patches ), to_tensor (masks )
269-
270-
271- class TrainBM4DDataLoader (DataLoader ):
243+ class TrainDataLoader (DataLoader ):
272244 """
273245 DataLoader that uses multithreading to fetch image patches from the cloud
274246 to form batches.
@@ -282,8 +254,11 @@ def __init__(self, dataset, batch_size=8, n_upds=20):
282254 ----------
283255 dataset : Dataset.ProposalDataset
284256 Instance of custom dataset.
285- batch_size : int
286- Number of samples per batch.
257+ batch_size : int, optional
258+ Number of samples per batch. Default is 8.
259+ n_upds : int, optional
260+ Number of back propagation gradient updates before validating the
261+ model. Default is 20.
287262
288263 Returns
289264 -------
@@ -316,45 +291,7 @@ def _load_batch(self, dummy_input):
316291 return to_tensor (noise_patches ), to_tensor (clean_patches ), None
317292
318293
319- class ValidateN2VDataLoader (DataLoader ):
320- """
321- DataLoader that uses multithreading to fetch image patches from the cloud
322- to form batches.
323- """
324-
325- def __init__ (self , dataset , batch_size = 8 ):
326- super ().__init__ (dataset , batch_size )
327-
328- def _get_iterator (self ):
329- return range (0 , len (self .dataset ), self .batch_size )
330-
331- def _load_batch (self , start_idx ):
332- # Compute batch size
333- n_remaining_examples = len (self .dataset ) - start_idx
334- batch_size = min (self .batch_size , n_remaining_examples )
335-
336- # Generate batch
337- with ThreadPoolExecutor () as executor :
338- # Assign threads
339- threads = list ()
340- for idx_shift in range (batch_size ):
341- idx = start_idx + idx_shift
342- threads .append (executor .submit (self .dataset .__getitem__ , idx ))
343-
344- # Process results
345- shape = (batch_size , 1 ,) + self .patch_shape
346- masked_patches = np .zeros (shape )
347- patches = np .zeros (shape )
348- masks = np .zeros (shape )
349- for i , thread in enumerate (as_completed (threads )):
350- masked_patch , patch , mask = thread .result ()
351- masked_patches [i , 0 , ...] = masked_patch
352- patches [i , 0 , ...] = patch
353- masks [i , 0 , ...] = mask
354- return to_tensor (masked_patches ), to_tensor (patches ), to_tensor (masks )
355-
356-
357- class ValidateBM4DDataLoader (DataLoader ):
294+ class ValidateDataLoader (DataLoader ):
358295 """
359296 DataLoader that uses multiprocessing to fetch image patches from the cloud
360297 to form batches.
@@ -399,30 +336,30 @@ def init_datasets(
399336 brain_ids ,
400337 img_paths_json ,
401338 patch_shape ,
402- n_validate_examples ,
403339 foreground_sampling_rate = 0.5 ,
404- method = "bm4d" ,
340+ n_validate_examples = 0 ,
405341 swc_dict = None
406342):
407343 # Initializations
408- transform = N2VManipulate () if method == "n2v" else img_util .BM4D ()
409344 train_dataset = TrainDataset (
410- patch_shape ,
411- transform ,
412- foreground_sampling_rate = foreground_sampling_rate ,
345+ patch_shape , foreground_sampling_rate = foreground_sampling_rate ,
413346 )
414- val_dataset = ValidateDataset (patch_shape , transform )
347+ val_dataset = ValidateDataset (patch_shape )
415348
416349 # Load data
417350 for brain_id in tqdm (brain_ids , desc = "Load Data" ):
418- img_path = img_util .get_img_prefix (brain_id , img_paths_json )
351+ # Set image path
352+ img_path = get_img_prefix (brain_id , img_paths_json )
419353 img_path += str (0 )
354+
355+ # Set SWC path
420356 if swc_dict :
421357 swc_pointer = deepcopy (swc_dict )
422358 swc_pointer ["path" ] += f"/{ brain_id } /world"
423359 else :
424360 swc_pointer = None
425361
362+ # Ingest data
426363 train_dataset .ingest_img (brain_id , img_path , swc_pointer )
427364 val_dataset .ingest_img (brain_id , img_path )
428365
@@ -436,7 +373,7 @@ def init_datasets(
436373
437374def to_tensor (arr ):
438375 """
439- Converts a numpy array to a torch tensor.
376+ Converts the given numpy array to a torch tensor.
440377
441378 Parameters
442379 ----------
0 commit comments