66import torch .nn as nn
77import wandb
88from torch .utils .data import DataLoader
9+ from torchvision import transforms
10+ from tqdm import tqdm
911
1012from utils import MetricWrapper , createfolders , load_data , load_model
1113
@@ -49,15 +51,13 @@ def main():
4951 )
5052 parser .add_argument (
5153 "--savemodel" ,
52- type = bool ,
53- default = False ,
54+ action = "store_true" ,
5455 help = "Whether model should be saved or not." ,
5556 )
5657
5758 parser .add_argument (
5859 "--download-data" ,
59- type = bool ,
60- default = False ,
60+ action = "store_true" ,
6161 help = "Whether the data should be downloaded or not. Might cause code to start a bit slowly." ,
6262 )
6363
@@ -126,17 +126,27 @@ def main():
126126
127127 metrics = MetricWrapper (* args .metric )
128128
129+ augmentations = transforms .Compose (
130+ [
131+ transforms .Resize ((16 , 16 )), # At least for USPS
132+ transforms .ToTensor (),
133+ ]
134+ )
135+
129136 # Dataset
130137 traindata = load_data (
131138 args .dataset ,
132139 train = True ,
133140 data_path = args .datafolder ,
134141 download = args .download_data ,
142+ transform = augmentations ,
135143 )
136144 validata = load_data (
137145 args .dataset ,
138146 train = False ,
139147 data_path = args .datafolder ,
148+ download = args .download_data ,
149+ transform = augmentations ,
140150 )
141151
142152 # Find number of channels in the dataset
@@ -153,34 +163,53 @@ def main():
153163 )
154164 model .to (device )
155165
156- trainloader = DataLoader (traindata ,
157- batch_size = args .batchsize ,
158- shuffle = True ,
159- pin_memory = True ,
160- drop_last = True )
161- valiloader = DataLoader (validata ,
162- batch_size = args .batchsize ,
163- shuffle = False ,
164- pin_memory = True )
166+ trainloader = DataLoader (
167+ traindata ,
168+ batch_size = args .batchsize ,
169+ shuffle = True ,
170+ pin_memory = True ,
171+ drop_last = True ,
172+ )
173+ valiloader = DataLoader (
174+ validata , batch_size = args .batchsize , shuffle = False , pin_memory = True
175+ )
165176
166177 criterion = nn .CrossEntropyLoss ()
167178 optimizer = th .optim .Adam (model .parameters (), lr = args .learning_rate )
168179
169180 # This allows us to load all the components without running the training loop
170181 if args .dry_run :
171- print ("Dry run completed" )
182+ dry_run_loader = DataLoader (
183+ traindata ,
184+ batch_size = 1 ,
185+ shuffle = True ,
186+ pin_memory = True ,
187+ drop_last = True ,
188+ )
189+
190+ for x , y in tqdm (dry_run_loader , desc = "Dry run" , total = 1 ):
191+ x , y = x .to (device ), y .to (device )
192+ pred = model .forward (x )
193+
194+ loss = criterion (y , pred )
195+ loss .backward ()
196+
197+ optimizer .step ()
198+ optimizer .zero_grad (set_to_none = True )
199+
200+ break
201+
202+ print ("Dry run completed successfully." )
172203 exit (0 )
173204
174- wandb .init (project = '' ,
175- tags = [])
205+ wandb .init (project = "" , tags = [])
176206 wandb .watch (model )
177207
178208 for epoch in range (args .epoch ):
179-
180209 # Training loop start
181210 trainingloss = []
182211 model .train ()
183- for x , y in trainloader :
212+ for x , y in tqdm ( trainloader , desc = "Training" ) :
184213 x , y = x .to (device ), y .to (device )
185214 pred = model .forward (x )
186215
@@ -195,18 +224,20 @@ def main():
195224 # Eval loop start
196225 model .eval ()
197226 with th .no_grad ():
198- for x , y in valiloader :
227+ for x , y in tqdm ( valiloader , desc = "Validation" ) :
199228 x , y = x .to (device ), y .to (device )
200229 pred = model .forward (x )
201230 loss = criterion (y , pred )
202231 evalloss .append (loss .item ())
203232
204- wandb .log ({
205- 'Epoch' : epoch ,
206- 'Train loss' : np .mean (trainingloss ),
207- 'Evaluation Loss' : np .mean (evalloss )
208- })
233+ wandb .log (
234+ {
235+ "Epoch" : epoch ,
236+ "Train loss" : np .mean (trainingloss ),
237+ "Evaluation Loss" : np .mean (evalloss ),
238+ }
239+ )
209240
210241
211- if __name__ == ' __main__' :
242+ if __name__ == " __main__" :
212243 main ()
0 commit comments