Skip to content

Commit faac193

Browse files
committed
load_data now gives arguments to the datasets
1 parent 8ef502f commit faac193

File tree

3 files changed

+34
-16
lines changed

3 files changed

+34
-16
lines changed

main.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def main():
7373
"--dataset",
7474
type=str,
7575
default="svhn",
76-
choices=["svhn"],
76+
choices=["svhn", "usps_0-6"],
7777
help="Which dataset to train the model on.",
7878
)
7979

@@ -119,8 +119,17 @@ def main():
119119
metrics = MetricWrapper(*args.metric)
120120

121121
# Dataset
122-
traindata = load_data(args.dataset)
123-
validata = load_data(args.dataset)
122+
traindata = load_data(
123+
args.dataset,
124+
train=True,
125+
data_path=args.datafolder,
126+
download=args.download_data,
127+
)
128+
validata = load_data(
129+
args.dataset,
130+
train=False,
131+
data_path=args.datafolder,
132+
)
124133

125134
trainloader = DataLoader(traindata,
126135
batch_size=args.batchsize,
@@ -144,7 +153,7 @@ def main():
144153
# Training loop start
145154
trainingloss = []
146155
model.train()
147-
for x, y in traindata:
156+
for x, y in trainloader:
148157
x, y = x.to(device), y.to(device)
149158
pred = model.forward(x)
150159

utils/dataloaders/usps_0_6.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,21 @@ class USPSDataset0_6(Dataset):
1717
1818
Args
1919
----
20-
path : pathlib.Path
20+
data_path : pathlib.Path
2121
Path to the USPS dataset file.
22-
mode : str
23-
Mode of the dataset. Must be either 'train' or 'test'.
22+
train : bool, optional
23+
Mode of the dataset.
2424
transform : callable, optional
2525
A function/transform that takes in a sample and returns a transformed version.
26+
download : bool, optional
27+
Whether to download the Dataset.
2628
2729
Attributes
2830
----------
2931
path : pathlib.Path
3032
Path to the USPS dataset file.
3133
mode : str
32-
Mode of the dataset.
34+
Mode of the dataset, either train or test.
3335
transform : callable
3436
A function/transform that takes in a sample and returns a transformed version.
3537
idx : numpy.ndarray
@@ -59,15 +61,21 @@ class USPSDataset0_6(Dataset):
5961
6
6062
"""
6163

62-
def __init__(self, path: Path, mode: str = "train", transform=None):
64+
def __init__(
65+
self,
66+
data_path: Path,
67+
train: bool = False,
68+
transform=None,
69+
download: bool = False,
70+
):
6371
super().__init__()
64-
self.path = path
65-
self.mode = mode
72+
self.path = list(data_path.glob("*.h5"))[0]
6673
self.transform = transform
6774

68-
if self.mode not in ["train", "test"]:
69-
raise ValueError("Invalid mode. Must be either 'train' or 'test'")
75+
if download:
76+
raise NotImplementedError("Download functionality not implemented.")
7077

78+
self.mode = "train" if train else "test"
7179
self.idx = self._index()
7280

7381
def _index(self):

utils/load_data.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
from dataloaders import USPS_0_6
21
from torch.utils.data import Dataset
32

3+
from .dataloaders import USPSDataset0_6
44

5-
def load_data(dataset: str) -> Dataset:
5+
6+
def load_data(dataset: str, *args, **kwargs) -> Dataset:
67
match dataset.lower():
78
case "usps_0-6":
8-
return USPS_0_6
9+
return USPSDataset0_6(*args, **kwargs)
910
case _:
1011
raise ValueError(f"Dataset: {dataset} not implemented.")

0 commit comments

Comments
 (0)