Skip to content

Commit e668667

Browse files
committed
Fix minor bugs regarding metric calculation
1 parent 270013c commit e668667

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

src/training/train.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@ def train(self, max_epochs: int):
3939
def run_epoch(self, epoch: int):
4040
self.model.train()
4141

42-
loss_sum = 0
43-
acc1_sum = 0
44-
acc5_sum = 0
45-
mcc_sum = 0
42+
loss_sum = 0.0
43+
acc1_sum = 0.0
44+
acc5_sum = 0.0
45+
mcc_sum = 0.0
4646
num_classes = self.my_dataset.num_classes
4747
num_batches = len(self.train_loader)
4848
label_frequencies = torch.zeros(self.my_dataset.num_classes)
@@ -63,7 +63,7 @@ def run_epoch(self, epoch: int):
6363
outputs, labels, "multiclass", num_classes=num_classes, top_k=5
6464
)
6565
mcc_sum += torchmetrics.functional.matthews_corrcoef(
66-
outputs, labels, num_classes=num_classes
66+
outputs, labels, "multiclass", num_classes=num_classes
6767
)
6868
label_frequencies += torch.bincount(
6969
labels, minlength=self.my_dataset.num_classes
@@ -103,10 +103,10 @@ def run_epoch(self, epoch: int):
103103
def test(self, epoch: int):
104104
self.model.eval()
105105

106-
loss_sum = 0
107-
acc1_sum = 0
108-
acc5_sum = 0
109-
mcc_sum = 0
106+
loss_sum = 0.0
107+
acc1_sum = 0.0
108+
acc5_sum = 0.0
109+
mcc_sum = 0.0
110110
num_classes = self.my_dataset.num_classes
111111
num_batches = len(self.test_loader)
112112
start_time = time()
@@ -124,7 +124,7 @@ def test(self, epoch: int):
124124
outputs, labels, "multiclass", num_classes=num_classes, top_k=5
125125
)
126126
mcc_sum += torchmetrics.functional.matthews_corrcoef(
127-
outputs, labels, num_classes=num_classes
127+
outputs, labels, "multiclass", num_classes=num_classes
128128
)
129129

130130
self.log_minibatch(labels, i, train=False)

0 commit comments

Comments
 (0)