@@ -39,10 +39,10 @@ def train(self, max_epochs: int):
39
39
def run_epoch (self , epoch : int ):
40
40
self .model .train ()
41
41
42
- loss_sum = 0
43
- acc1_sum = 0
44
- acc5_sum = 0
45
- mcc_sum = 0
42
+ loss_sum = 0.0
43
+ acc1_sum = 0.0
44
+ acc5_sum = 0.0
45
+ mcc_sum = 0.0
46
46
num_classes = self .my_dataset .num_classes
47
47
num_batches = len (self .train_loader )
48
48
label_frequencies = torch .zeros (self .my_dataset .num_classes )
@@ -63,7 +63,7 @@ def run_epoch(self, epoch: int):
63
63
outputs , labels , "multiclass" , num_classes = num_classes , top_k = 5
64
64
)
65
65
mcc_sum += torchmetrics .functional .matthews_corrcoef (
66
- outputs , labels , num_classes = num_classes
66
+ outputs , labels , "multiclass" , num_classes = num_classes
67
67
)
68
68
label_frequencies += torch .bincount (
69
69
labels , minlength = self .my_dataset .num_classes
@@ -103,10 +103,10 @@ def run_epoch(self, epoch: int):
103
103
def test (self , epoch : int ):
104
104
self .model .eval ()
105
105
106
- loss_sum = 0
107
- acc1_sum = 0
108
- acc5_sum = 0
109
- mcc_sum = 0
106
+ loss_sum = 0.0
107
+ acc1_sum = 0.0
108
+ acc5_sum = 0.0
109
+ mcc_sum = 0.0
110
110
num_classes = self .my_dataset .num_classes
111
111
num_batches = len (self .test_loader )
112
112
start_time = time ()
@@ -124,7 +124,7 @@ def test(self, epoch: int):
124
124
outputs , labels , "multiclass" , num_classes = num_classes , top_k = 5
125
125
)
126
126
mcc_sum += torchmetrics .functional .matthews_corrcoef (
127
- outputs , labels , num_classes = num_classes
127
+ outputs , labels , "multiclass" , num_classes = num_classes
128
128
)
129
129
130
130
self .log_minibatch (labels , i , train = False )
0 commit comments