1919from torch .utils .data import Dataset
2020from tqdm import tqdm
2121
22- import logging
2322import numpy as np
24- import pytorch_lightning as L
2523import random
2624import torch
2725
2826from aind_exaspim_image_compression .utils import img_util , util
2927from aind_exaspim_image_compression .utils .swc_util import Reader
3028
3129
32- logging .getLogger ("urllib3.connectionpool" ).setLevel (logging .WARNING )
33-
34-
35- class DataModule (L .LightningDataModule ):
36- def __init__ (
37- self ,
38- brain_ids ,
39- img_paths_json ,
40- swc_dir ,
41- batch_size = 16 ,
42- foreground_sampling_rate = 0.5 ,
43- n_upds = 100 ,
44- n_validate_examples = 200 ,
45- patch_shape = (64 , 64 , 64 ),
46- ):
47- # Call parent class
48- super (DataModule , self ).__init__ ()
49-
50- # Instance attributes
51- self .batch_size = batch_size
52- self .brain_ids = brain_ids
53- self .foreground_sampling_rate = foreground_sampling_rate
54- self .n_upds = n_upds
55- self .n_validate_examples = n_validate_examples
56- self .patch_shape = patch_shape
57-
58- # Paths
59- self .img_paths_json = img_paths_json
60- self .swc_dir = swc_dir
61-
62- def prepare_data (self ):
63- pass
64-
65- def setup (self , stage = None ):
66- if stage == "fit" or stage is None :
67- self .train_dataset , self .val_dataset = init_datasets (
68- self .brain_ids ,
69- self .img_paths_json ,
70- self .swc_dir ,
71- self .patch_shape ,
72- self .n_validate_examples ,
73- self .foreground_sampling_rate ,
74- )
75-
76- def train_dataloader (self ):
77- train_dataloader = TrainN2VDataLoader (
78- self .train_dataset , batch_size = self .batch_size , n_upds = self .n_upds
79- )
80- return train_dataloader
81-
82- def val_dataloader (self ):
83- val_dataloader = ValidateN2VDataLoader (
84- self .val_dataset , batch_size = self .batch_size
85- )
86- return val_dataloader
87-
88-
8930# --- Custom Datasets ---
9031class TrainDataset (Dataset ):
9132 def __init__ (
@@ -134,16 +75,16 @@ def ingest_swcs(self, swc_pointer):
13475
13576 def __len__ (self ):
13677 """
137- Counts the number of whole-brain images in dataset.
78+ Counts the number of whole-brain images in the dataset.
13879
13980 Parameters
14081 ----------
14182 None
14283
14384 Returns
14485 -------
145- Number of whole-brain images in dataset.
146-
86+ int
87+ Number of whole-brain images in the dataset.
14788 """
14889 return len (self .imgs )
14990
@@ -200,16 +141,16 @@ def __init__(self, patch_shape, transform):
200141
201142 def __len__ (self ):
202143 """
203- Counts the number of whole-brain images in dataset.
144+ Counts the number of whole-brain images in the dataset.
204145
205146 Parameters
206147 ----------
207148 None
208149
209150 Returns
210151 -------
211- Number of whole-brain images in dataset.
212-
152+ int
153+ Number of whole-brain images in the dataset.
213154 """
214155 return len (self .ids )
215156
@@ -245,15 +186,40 @@ class DataLoader(ABC):
245186 """
246187 DataLoader that uses multithreading to fetch image patches from the cloud
247188 to form batches.
248-
249189 """
250190
251191 def __init__ (self , dataset , batch_size = 16 ):
192+ """
193+ Instantiates a DataLoader object.
194+
195+ Parameters
196+ ----------
197+ dataset : torch.utils.data.Dataset
198+ Dataset to iterated over.
199+ batch_size : int
200+ Number of examples in each batch.
201+
202+ Returns
203+ -------
204+ None
205+ """
252206 self .dataset = dataset
253207 self .batch_size = batch_size
254208 self .patch_shape = dataset .patch_shape
255209
256210 def __iter__ (self ):
211+ """
212+ Iterates over the dataset and yields batches of examples.
213+
214+ Parameters
215+ ----------
216+ None
217+
218+ Returns
219+ -------
220+ iterator
221+ Yields batches of examples.
222+ """
257223 for idx in self ._get_iterator ():
258224 yield self ._load_batch (idx )
259225
@@ -270,7 +236,6 @@ class TrainN2VDataLoader(DataLoader):
270236 """
271237 DataLoader that uses multithreading to fetch image patches from the cloud
272238 to form batches to train Noise2Void (N2V).
273-
274239 """
275240
276241 def __init__ (self , dataset , batch_size = 16 , n_upds = 100 ):
@@ -307,7 +272,6 @@ class TrainBM4DDataLoader(DataLoader):
307272 """
308273 DataLoader that uses multithreading to fetch image patches from the cloud
309274 to form batches.
310-
311275 """
312276
313277 def __init__ (self , dataset , batch_size = 8 , n_upds = 20 ):
@@ -324,7 +288,6 @@ def __init__(self, dataset, batch_size=8, n_upds=20):
324288 Returns
325289 -------
326290 None
327-
328291 """
329292 # Call parent class
330293 super ().__init__ (dataset , batch_size )
@@ -357,7 +320,6 @@ class ValidateN2VDataLoader(DataLoader):
357320 """
358321 DataLoader that uses multithreading to fetch image patches from the cloud
359322 to form batches.
360-
361323 """
362324
363325 def __init__ (self , dataset , batch_size = 8 ):
@@ -396,7 +358,6 @@ class ValidateBM4DDataLoader(DataLoader):
396358 """
397359 DataLoader that uses multiprocessing to fetch image patches from the cloud
398360 to form batches.
399-
400361 """
401362
402363 def __init__ (self , dataset , batch_size = 8 ):
@@ -475,7 +436,7 @@ def init_datasets(
475436
476437def to_tensor (arr ):
477438 """
478- Converts a numpy array to a tensor.
439+ Converts a numpy array to a torch tensor.
479440
480441 Parameters
481442 ----------
@@ -485,7 +446,6 @@ def to_tensor(arr):
485446 Returns
486447 -------
487448 torch.Tensor
488- Array converted to tensor.
489-
449+ Array converted to a torch tensor.
490450 """
491451 return torch .tensor (arr , dtype = torch .float )
0 commit comments