99from utils import MetricWrapper , createfolders , get_args , load_data , load_model
1010from wandb_api import WANDB_API
1111
12+
1213def main ():
1314 """
1415
@@ -46,7 +47,21 @@ def main():
4647 val_size = args .val_size ,
4748 )
4849
49- metrics = MetricWrapper (* args .metric , num_classes = traindata .num_classes , macro_averaging = args .macro_averaging )
50+ train_metrics = MetricWrapper (
51+ * args .metric ,
52+ num_classes = traindata .num_classes ,
53+ macro_averaging = args .macro_averaging ,
54+ )
55+ val_metrics = MetricWrapper (
56+ * args .metric ,
57+ num_classes = traindata .num_classes ,
58+ macro_averaging = args .macro_averaging ,
59+ )
60+ test_metrics = MetricWrapper (
61+ * args .metric ,
62+ num_classes = traindata .num_classes ,
63+ macro_averaging = args .macro_averaging ,
64+ )
5065
5166 # Find the shape of the data, if is 2D, add a channel dimension
5267 data_shape = traindata [0 ][0 ].shape
@@ -98,22 +113,22 @@ def main():
98113 optimizer .step ()
99114 optimizer .zero_grad (set_to_none = True )
100115
101- metrics (y , logits )
116+ train_metrics (y , logits )
102117
103118 break
104- print (metrics .accumulate ())
119+ print (train_metrics .accumulate ())
105120 print ("Dry run completed successfully." )
106121 exit ()
107122
108123 # wandb.login(key=WANDB_API)
109124 wandb .init (
110- entity = "ColabCode" ,
111- # entity="FYS-8805 Exam",
112- project = "Jan" ,
113- tags = [args .modelname , args .dataset ]
114- )
125+ entity = "ColabCode" ,
126+ # entity="FYS-8805 Exam",
127+ project = "Jan" ,
128+ tags = [args .modelname , args .dataset ],
129+ )
115130 wandb .watch (model )
116-
131+
117132 for epoch in range (args .epoch ):
118133 # Training loop start
119134 trainingloss = []
@@ -129,10 +144,7 @@ def main():
129144 optimizer .zero_grad (set_to_none = True )
130145 trainingloss .append (loss .item ())
131146
132- metrics (y , logits )
133-
134- wandb .log (metrics .accumulate (str_prefix = "Train " ))
135- metrics .reset ()
147+ train_metrics (y , logits )
136148
137149 valloss = []
138150 # Validation loop start
@@ -144,18 +156,19 @@ def main():
144156 loss = criterion (logits , y )
145157 valloss .append (loss .item ())
146158
147- metrics (y , logits )
148-
149- wandb .log (metrics .accumulate (str_prefix = "Validation " ))
150- metrics .reset ()
159+ val_metrics (y , logits )
151160
152161 wandb .log (
153162 {
154163 "Epoch" : epoch ,
155164 "Train loss" : np .mean (trainingloss ),
156165 "Validation loss" : np .mean (valloss ),
157166 }
167+ | train_metrics .accumulate (str_prefix = "Train " )
168+ | val_metrics .accumulate (str_prefix = "Validation " )
158169 )
170+ train_metrics .reset ()
171+ val_metrics .reset ()
159172
160173 testloss = []
161174 model .eval ()
@@ -167,11 +180,13 @@ def main():
167180 testloss .append (loss .item ())
168181
169182 preds = th .argmax (logits , dim = 1 )
170- metrics (y , preds )
183+ test_metrics (y , preds )
171184
172- wandb .log (metrics .accumulate (str_prefix = "Test " ))
173- metrics .reset ()
174- wandb .log ({"Test loss" : np .mean (testloss )})
185+ wandb .log (
186+ {"Epoch" : 1 , "Test loss" : np .mean (testloss )}
187+ | test_metrics .accumulate (str_prefix = "Test " )
188+ )
189+ test_metrics .reset ()
175190
176191
177192if __name__ == "__main__" :
0 commit comments