2626
2727# --- Custom Datasets ---
2828class TrainDataset (Dataset ):
29+ """
30+ A PyTorch Dataset for sampling 3D patches from whole-brain images and
31+ applying the BM4D denoising algorithm. The dataset's __getitem__ method
32+ returns both the original and denoised patches. Optionally, the patch
33+ sampling maybe biased toward foreground regions by using the voxel
34+ coordinates from SWC files that represent neuron tracings.
35+ """
2936
3037 def __init__ (
3138 self ,
3239 patch_shape ,
3340 anisotropy = (0.748 , 0.748 , 1.0 ),
3441 boundary_buffer = 4000 ,
35- foreground_sampling_rate = 0.25 ,
42+ foreground_sampling_rate = 0.5 ,
43+ sigma = 50 ,
3644 ):
3745 # Call parent class
3846 super (TrainDataset , self ).__init__ ()
3947
4048 # Class attributes
4149 self .anisotropy = anisotropy
4250 self .boundary_buffer = boundary_buffer
43- self .denoise_bm4d = BM4D ()
51+ self .denoise_bm4d = BM4D (sigma = sigma )
4452 self .foreground_sampling_rate = foreground_sampling_rate
4553 self .patch_shape = patch_shape
4654 self .swc_reader = Reader ()
@@ -50,11 +58,11 @@ def __init__(
5058 self .imgs = dict ()
5159
5260 # --- Ingest data ---
53- def ingest_img (self , brain_id , img_path , swc_pointer ):
54- self .foreground [brain_id ] = self .ingest_swcs (swc_pointer )
61+ def ingest_brain (self , brain_id , img_path , swc_pointer ):
62+ self .foreground [brain_id ] = self .load_swcs (swc_pointer )
5563 self .imgs [brain_id ] = img_util .read (img_path )
5664
57- def ingest_swcs (self , swc_pointer ):
65+ def load_swcs (self , swc_pointer ):
5866 if swc_pointer :
5967 # Initializations
6068 swc_dicts = self .swc_reader .read (swc_pointer )
@@ -146,7 +154,7 @@ def __len__(self):
146154 """
147155 return len (self .ids )
148156
149- def ingest_img (self , brain_id , img_path ):
157+ def ingest_brain (self , brain_id , img_path ):
150158 self .imgs [brain_id ] = img_util .read (img_path )
151159
152160 def ingest_example (self , brain_id , voxel ):
@@ -190,10 +198,6 @@ def __init__(self, dataset, batch_size=16):
190198 Dataset to iterated over.
191199 batch_size : int
192200 Number of examples in each batch.
193-
194- Returns
195- -------
196- None
197201 """
198202 # Instance attributes
199203 self .dataset = dataset
@@ -337,8 +341,8 @@ def init_datasets(
337341 swc_pointer = None
338342
339343 # Ingest data
340- train_dataset .ingest_img (brain_id , img_path , swc_pointer )
341- val_dataset .ingest_img (brain_id , img_path )
344+ train_dataset .ingest_brain (brain_id , img_path , swc_pointer )
345+ val_dataset .ingest_brain (brain_id , img_path )
342346
343347 # Generate validation examples
344348 for _ in range (n_validate_examples ):
0 commit comments