@@ -31,7 +31,6 @@ def main():
3131
3232 device = args .device
3333
34-
3534 if args .dataset .lower () == "usps_0-6" or args .dataset .lower () == "uspsh5_7_9" :
3635 augmentations = transforms .Compose (
3736 [
@@ -58,8 +57,8 @@ def main():
5857 transform = augmentations ,
5958 )
6059
61- metrics = MetricWrapper (* args .metric , num_classes = traindata .num_classes )
62-
60+ metrics = MetricWrapper (* args .metric , num_classes = traindata .num_classes )
61+
6362 # Find the shape of the data, if is 2D, add a channel dimension
6463 data_shape = traindata [0 ][0 ].shape
6564 if len (data_shape ) == 2 :
@@ -106,18 +105,17 @@ def main():
106105
107106 optimizer .step ()
108107 optimizer .zero_grad (set_to_none = True )
109-
108+
110109 preds = th .argmax (logits , dim = 1 )
111110 metrics (y , preds )
112111
113-
114112 break
115113 print (metrics .__getmetrics__ ())
116114 print ("Dry run completed successfully." )
117115 exit (0 )
118116
119117 wandb .login (key = WANDB_API )
120- wandb .init (entity = "ColabCode" ,project = "Jan" , tags = [args .modelname , args .dataset ])
118+ wandb .init (entity = "ColabCode" , project = "Jan" , tags = [args .modelname , args .dataset ])
121119 wandb .watch (model )
122120
123121 for epoch in range (args .epoch ):
@@ -134,10 +132,10 @@ def main():
134132 optimizer .step ()
135133 optimizer .zero_grad (set_to_none = True )
136134 trainingloss .append (loss .item ())
137-
135+
138136 preds = th .argmax (logits , dim = 1 )
139137 metrics (y , preds )
140-
138+
141139 wandb .log (metrics .__getmetrics__ (str_prefix = "Train " ))
142140 metrics .__resetvalues__ ()
143141
@@ -150,10 +148,10 @@ def main():
150148 logits = model .forward (x )
151149 loss = criterion (logits , y )
152150 evalloss .append (loss .item ())
153-
151+
154152 preds = th .argmax (logits , dim = 1 )
155153 metrics (y , preds )
156-
154+
157155 wandb .log (metrics .__getmetrics__ (str_prefix = "Evaluation " ))
158156 metrics .__resetvalues__ ()
159157
0 commit comments