Skip to content

Commit e29b36d

Browse files
authored
refactor: removed data module, doc
1 parent 5e2035e commit e29b36d

File tree

1 file changed

+34
-74
lines changed

1 file changed

+34
-74
lines changed

src/aind_exaspim_image_compression/machine_learning/data_handling.py

Lines changed: 34 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -19,73 +19,14 @@
1919
from torch.utils.data import Dataset
2020
from tqdm import tqdm
2121

22-
import logging
2322
import numpy as np
24-
import pytorch_lightning as L
2523
import random
2624
import torch
2725

2826
from aind_exaspim_image_compression.utils import img_util, util
2927
from aind_exaspim_image_compression.utils.swc_util import Reader
3028

3129

32-
logging.getLogger("urllib3.connectionpool").setLevel(logging.WARNING)
33-
34-
35-
class DataModule(L.LightningDataModule):
36-
def __init__(
37-
self,
38-
brain_ids,
39-
img_paths_json,
40-
swc_dir,
41-
batch_size=16,
42-
foreground_sampling_rate=0.5,
43-
n_upds=100,
44-
n_validate_examples=200,
45-
patch_shape=(64, 64, 64),
46-
):
47-
# Call parent class
48-
super(DataModule, self).__init__()
49-
50-
# Instance attributes
51-
self.batch_size = batch_size
52-
self.brain_ids = brain_ids
53-
self.foreground_sampling_rate = foreground_sampling_rate
54-
self.n_upds = n_upds
55-
self.n_validate_examples = n_validate_examples
56-
self.patch_shape = patch_shape
57-
58-
# Paths
59-
self.img_paths_json = img_paths_json
60-
self.swc_dir = swc_dir
61-
62-
def prepare_data(self):
63-
pass
64-
65-
def setup(self, stage=None):
66-
if stage == "fit" or stage is None:
67-
self.train_dataset, self.val_dataset = init_datasets(
68-
self.brain_ids,
69-
self.img_paths_json,
70-
self.swc_dir,
71-
self.patch_shape,
72-
self.n_validate_examples,
73-
self.foreground_sampling_rate,
74-
)
75-
76-
def train_dataloader(self):
77-
train_dataloader = TrainN2VDataLoader(
78-
self.train_dataset, batch_size=self.batch_size, n_upds=self.n_upds
79-
)
80-
return train_dataloader
81-
82-
def val_dataloader(self):
83-
val_dataloader = ValidateN2VDataLoader(
84-
self.val_dataset, batch_size=self.batch_size
85-
)
86-
return val_dataloader
87-
88-
8930
# --- Custom Datasets ---
9031
class TrainDataset(Dataset):
9132
def __init__(
@@ -134,16 +75,16 @@ def ingest_swcs(self, swc_pointer):
13475

13576
def __len__(self):
13677
"""
137-
Counts the number of whole-brain images in dataset.
78+
Counts the number of whole-brain images in the dataset.
13879
13980
Parameters
14081
----------
14182
None
14283
14384
Returns
14485
-------
145-
Number of whole-brain images in dataset.
146-
86+
int
87+
Number of whole-brain images in the dataset.
14788
"""
14889
return len(self.imgs)
14990

@@ -200,16 +141,16 @@ def __init__(self, patch_shape, transform):
200141

201142
def __len__(self):
202143
"""
203-
Counts the number of whole-brain images in dataset.
144+
Counts the number of whole-brain images in the dataset.
204145
205146
Parameters
206147
----------
207148
None
208149
209150
Returns
210151
-------
211-
Number of whole-brain images in dataset.
212-
152+
int
153+
Number of whole-brain images in the dataset.
213154
"""
214155
return len(self.ids)
215156

@@ -245,15 +186,40 @@ class DataLoader(ABC):
245186
"""
246187
DataLoader that uses multithreading to fetch image patches from the cloud
247188
to form batches.
248-
249189
"""
250190

251191
def __init__(self, dataset, batch_size=16):
192+
"""
193+
Instantiates a DataLoader object.
194+
195+
Parameters
196+
----------
197+
dataset : torch.utils.data.Dataset
198+
Dataset to iterate over.
199+
batch_size : int
200+
Number of examples in each batch.
201+
202+
Returns
203+
-------
204+
None
205+
"""
252206
self.dataset = dataset
253207
self.batch_size = batch_size
254208
self.patch_shape = dataset.patch_shape
255209

256210
def __iter__(self):
211+
"""
212+
Iterates over the dataset and yields batches of examples.
213+
214+
Parameters
215+
----------
216+
None
217+
218+
Returns
219+
-------
220+
iterator
221+
Yields batches of examples.
222+
"""
257223
for idx in self._get_iterator():
258224
yield self._load_batch(idx)
259225

@@ -270,7 +236,6 @@ class TrainN2VDataLoader(DataLoader):
270236
"""
271237
DataLoader that uses multithreading to fetch image patches from the cloud
272238
to form batches to train Noise2Void (N2V).
273-
274239
"""
275240

276241
def __init__(self, dataset, batch_size=16, n_upds=100):
@@ -307,7 +272,6 @@ class TrainBM4DDataLoader(DataLoader):
307272
"""
308273
DataLoader that uses multithreading to fetch image patches from the cloud
309274
to form batches.
310-
311275
"""
312276

313277
def __init__(self, dataset, batch_size=8, n_upds=20):
@@ -324,7 +288,6 @@ def __init__(self, dataset, batch_size=8, n_upds=20):
324288
Returns
325289
-------
326290
None
327-
328291
"""
329292
# Call parent class
330293
super().__init__(dataset, batch_size)
@@ -357,7 +320,6 @@ class ValidateN2VDataLoader(DataLoader):
357320
"""
358321
DataLoader that uses multithreading to fetch image patches from the cloud
359322
to form batches.
360-
361323
"""
362324

363325
def __init__(self, dataset, batch_size=8):
@@ -396,7 +358,6 @@ class ValidateBM4DDataLoader(DataLoader):
396358
"""
397359
DataLoader that uses multiprocessing to fetch image patches from the cloud
398360
to form batches.
399-
400361
"""
401362

402363
def __init__(self, dataset, batch_size=8):
@@ -475,7 +436,7 @@ def init_datasets(
475436

476437
def to_tensor(arr):
477438
"""
478-
Converts a numpy array to a tensor.
439+
Converts a numpy array to a torch tensor.
479440
480441
Parameters
481442
----------
@@ -485,7 +446,6 @@ def to_tensor(arr):
485446
Returns
486447
-------
487448
torch.Tensor
488-
Array converted to tensor.
489-
449+
Array converted to a torch tensor.
490450
"""
491451
return torch.tensor(arr, dtype=torch.float)

0 commit comments

Comments
 (0)