@@ -107,18 +107,22 @@ def main():
107107 optimizer .step ()
108108 optimizer .zero_grad (set_to_none = True )
109109
110- preds = th .argmax (logits , dim = 1 )
111- metrics (y , preds )
110+ metrics (y , logits )
112111
113112 break
114113 print (metrics .accumulate ())
115114 print ("Dry run completed successfully." )
116115 exit (0 )
117116
118- wandb .login (key = WANDB_API )
119- wandb .init (entity = "ColabCode" , project = "Jan" , tags = [args .modelname , args .dataset ])
117+ # wandb.login(key=WANDB_API)
118+ wandb .init (
119+ entity = "ColabCode-org" ,
120+ # entity="FYS-8805 Exam",
121+ project = "Test" ,
122+ tags = [args .modelname , args .dataset ]
123+ )
120124 wandb .watch (model )
121-
125+ exit ()
122126 for epoch in range (args .epoch ):
123127 # Training loop start
124128 trainingloss = []
@@ -134,8 +138,7 @@ def main():
134138 optimizer .zero_grad (set_to_none = True )
135139 trainingloss .append (loss .item ())
136140
137- preds = th .argmax (logits , dim = 1 )
138- metrics (y , preds )
141+ metrics (y , logits )
139142
140143 wandb .log (metrics .accumulate (str_prefix = "Train " ))
141144 metrics .reset ()
@@ -150,8 +153,7 @@ def main():
150153 loss = criterion (logits , y )
151154 evalloss .append (loss .item ())
152155
153- preds = th .argmax (logits , dim = 1 )
154- metrics (y , preds )
156+ metrics (y , logits )
155157
156158 wandb .log (metrics .accumulate (str_prefix = "Evaluation " ))
157159 metrics .reset ()
0 commit comments