File tree Expand file tree Collapse file tree 3 files changed +34
-16
lines changed
Expand file tree Collapse file tree 3 files changed +34
-16
lines changed Original file line number Diff line number Diff line change @@ -73,7 +73,7 @@ def main():
7373 "--dataset" ,
7474 type = str ,
7575 default = "svhn" ,
76- choices = ["svhn" ],
76+ choices = ["svhn" , "usps_0-6" ],
7777 help = "Which dataset to train the model on." ,
7878 )
7979
@@ -119,8 +119,17 @@ def main():
119119 metrics = MetricWrapper (* args .metric )
120120
121121 # Dataset
122- traindata = load_data (args .dataset )
123- validata = load_data (args .dataset )
122+ traindata = load_data (
123+ args .dataset ,
124+ train = True ,
125+ data_path = args .datafolder ,
126+ download = args .download_data ,
127+ )
128+ validata = load_data (
129+ args .dataset ,
130+ train = False ,
131+ data_path = args .datafolder ,
132+ )
124133
125134 trainloader = DataLoader (traindata ,
126135 batch_size = args .batchsize ,
@@ -144,7 +153,7 @@ def main():
144153 # Training loop start
145154 trainingloss = []
146155 model .train ()
147- for x , y in traindata :
156+ for x , y in trainloader :
148157 x , y = x .to (device ), y .to (device )
149158 pred = model .forward (x )
150159
Original file line number Diff line number Diff line change @@ -17,19 +17,21 @@ class USPSDataset0_6(Dataset):
1717
1818 Args
1919 ----
20- path : pathlib.Path
20+ data_path : pathlib.Path
2121 Path to the USPS dataset file.
22- mode : str
23- Mode of the dataset. Must be either 'train' or 'test'.
22+ train : bool, optional
23+ Mode of the dataset.
2424 transform : callable, optional
2525 A function/transform that takes in a sample and returns a transformed version.
26+ download : bool, optional
27+ Whether to download the Dataset.
2628
2729 Attributes
2830 ----------
2931 path : pathlib.Path
3032 Path to the USPS dataset file.
3133 mode : str
32- Mode of the dataset.
34+ Mode of the dataset, either train or test .
3335 transform : callable
3436 A function/transform that takes in a sample and returns a transformed version.
3537 idx : numpy.ndarray
@@ -59,15 +61,21 @@ class USPSDataset0_6(Dataset):
5961 6
6062 """
6163
62- def __init__ (self , path : Path , mode : str = "train" , transform = None ):
64+ def __init__ (
65+ self ,
66+ data_path : Path ,
67+ train : bool = False ,
68+ transform = None ,
69+ download : bool = False ,
70+ ):
6371 super ().__init__ ()
64- self .path = path
65- self .mode = mode
72+ self .path = list (data_path .glob ("*.h5" ))[0 ]
6673 self .transform = transform
6774
68- if self . mode not in [ "train" , "test" ] :
69- raise ValueError ( "Invalid mode. Must be either 'train' or 'test' " )
75+ if download :
76+ raise NotImplementedError ( "Download functionality not implemented. " )
7077
78+ self .mode = "train" if train else "test"
7179 self .idx = self ._index ()
7280
7381 def _index (self ):
Original file line number Diff line number Diff line change 1- from dataloaders import USPS_0_6
21from torch .utils .data import Dataset
32
3+ from .dataloaders import USPSDataset0_6
44
5- def load_data (dataset : str ) -> Dataset :
5+
6+ def load_data (dataset : str , * args , ** kwargs ) -> Dataset :
67 match dataset .lower ():
78 case "usps_0-6" :
8- return USPS_0_6
9+ return USPSDataset0_6 ( * args , ** kwargs )
910 case _:
1011 raise ValueError (f"Dataset: { dataset } not implemented." )
You can’t perform that action at this time.
0 commit comments