Commit 1f740ce

Merge pull request #79 from SFI-Visual-Intelligence/Jan-dev

Adjusted accuracy to fit new method names

2 parents 6fb5296 + 3a46ddc · commit 1f740ce
File tree: 8 files changed, +215 -53 lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -9,6 +9,9 @@ wandb/*
 wandb_api.py
 doc/autoapi

+*.DS_Store
+
+
 #Magnus specific
 job*
 env2/*

CollaborativeCoding/dataloaders/mnist_0_3.py

Lines changed: 1 addition & 0 deletions
@@ -53,6 +53,7 @@ def __init__(
         sample_ids: list,
         train: bool = False,
         transform=None,
+        nr_channels: int = 1,
     ):
         super().__init__()


CollaborativeCoding/metrics/accuracy.py

Lines changed: 26 additions & 31 deletions
@@ -1,3 +1,4 @@
+import numpy as np
 import torch
 from torch import nn

@@ -7,6 +8,8 @@ def __init__(self, num_classes, macro_averaging=False):
         super().__init__()
         self.num_classes = num_classes
         self.macro_averaging = macro_averaging
+        self.y_true = []
+        self.y_pred = []

     def forward(self, y_true, y_pred):
         """
@@ -26,12 +29,10 @@ def forward(self, y_true, y_pred):
         """
         if y_pred.dim() > 1:
             y_pred = y_pred.argmax(dim=1)
-        if self.macro_averaging:
-            return self._macro_acc(y_true, y_pred)
-        else:
-            return self._micro_acc(y_true, y_pred)
+        self.y_true.append(y_true)
+        self.y_pred.append(y_pred)

-    def _macro_acc(self, y_true, y_pred):
+    def _macro_acc(self):
         """
         Compute the macro-average accuracy.

@@ -47,7 +48,7 @@ def _macro_acc(self, y_true, y_pred):
         float
             Macro-average accuracy score.
         """
-        y_true, y_pred = y_true.flatten(), y_pred.flatten()  # Ensure 1D shape
+        y_true, y_pred = self.y_true.flatten(), self.y_pred.flatten()  # Ensure 1D shape

         classes = torch.unique(y_true)  # Find unique class labels
         acc_per_class = []
@@ -60,7 +61,7 @@ def _macro_acc(self, y_true, y_pred):
         macro_acc = torch.stack(acc_per_class).mean().item()  # Average across classes
         return macro_acc

-    def _micro_acc(self, y_true, y_pred):
+    def _micro_acc(self):
         """
         Compute the micro-average accuracy.

@@ -76,27 +77,21 @@ def _micro_acc(self, y_true, y_pred):
         float
             Micro-average accuracy score.
         """
-        return (y_true == y_pred).float().mean().item()
-
-
-if __name__ == "__main__":
-    accuracy = Accuracy(5)
-    macro_accuracy = Accuracy(5, macro_averaging=True)
-
-    y_true = torch.tensor([0, 3, 2, 3, 4])
-    y_pred = torch.tensor([0, 1, 2, 3, 4])
-    print(accuracy(y_true, y_pred))
-    print(macro_accuracy(y_true, y_pred))
-
-    y_true = torch.tensor([0, 3, 2, 3, 4])
-    y_onehot_pred = torch.tensor(
-        [
-            [1, 0, 0, 0, 0],
-            [0, 1, 0, 0, 0],
-            [0, 0, 1, 0, 0],
-            [0, 0, 0, 1, 0],
-            [0, 0, 0, 0, 1],
-        ]
-    )
-    print(accuracy(y_true, y_onehot_pred))
-    print(macro_accuracy(y_true, y_onehot_pred))
+        return (self.y_true == self.y_pred).float().mean().item()
+
+    def __returnmetric__(self):
+        if self.y_true == [] or self.y_pred == []:
+            return np.nan
+        if isinstance(self.y_true, list):
+            if len(self.y_true) == 1:
+                self.y_true = self.y_true[0]
+                self.y_pred = self.y_pred[0]
+            else:
+                self.y_true = torch.cat(self.y_true)
+                self.y_pred = torch.cat(self.y_pred)
+        return self._micro_acc() if not self.macro_averaging else self._macro_acc()
+
+    def __reset__(self):
+        self.y_true = []
+        self.y_pred = []
+        return None
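
The net effect of this file's changes: Accuracy no longer returns a score from forward; it accumulates each batch and computes the score only when __returnmetric__ is called, with __reset__ clearing the buffers. A minimal usage sketch of that interface (the import path is inferred from the file location and the tensors are illustrative, not taken from the repository):

import torch

# Path assumed from the diff's file location; adjust if the package re-exports Accuracy.
from CollaborativeCoding.metrics.accuracy import Accuracy

metric = Accuracy(num_classes=6, macro_averaging=False)

# forward() now only accumulates; nothing is computed per batch.
metric(torch.tensor([0, 1, 2]), torch.tensor([0, 1, 1]))  # batch 1
metric(torch.tensor([3, 4, 5]), torch.tensor([3, 4, 5]))  # batch 2

# __returnmetric__ concatenates the stored batches and returns the accuracy;
# __reset__ empties the buffers for the next epoch.
score = metric.__returnmetric__()  # 5 of 6 correct -> 0.8333...
metric.__reset__()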

main.py

Lines changed: 5 additions & 2 deletions
@@ -1,11 +1,11 @@
 import numpy as np
 import torch as th
 import torch.nn as nn
-import wandb
 from torch.utils.data import DataLoader
 from torchvision import transforms
 from tqdm import tqdm

+import wandb
 from CollaborativeCoding import (
     MetricWrapper,
     createfolders,
@@ -14,6 +14,9 @@
     load_model,
 )

+# from wandb_api import WANDB_API
+
+

 def main():
     """
@@ -126,7 +129,7 @@ def main():
         print("Dry run completed successfully.")
         exit()

-    # wandb.login(key=WANDB_API)
+    # wandb.login(key=WANDB_API)
     wandb.init(
         entity="ColabCode",
         project=args.run_name,
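
For context on the commented-out login: wandb_api.py is kept out of version control (it appears in .gitignore above) and is expected to hold the API key. A rough sketch of that wiring, under the assumption that wandb_api.py simply defines a WANDB_API string constant (that file is not part of this commit):

import wandb

# Assumed contents of the gitignored wandb_api.py:  WANDB_API = "your-api-key"
from wandb_api import WANDB_API  # assumed local module, not in the repository

wandb.login(key=WANDB_API)       # the line main.py currently keeps commented out
run = wandb.init(
    entity="ColabCode",          # entity mirrors the wandb.init call in the diff
    project="example-run-name",  # main.py passes args.run_name here
)
run.finish()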

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@ dependencies = [
     "torch>=2.6.0",
     "torchvision>=0.21.0",
     "tqdm>=4.67.1",
+    "wandb>=0.19.6",
 ]
 [tool.isort]
 profile = "black"

tests/test_dataloaders.py

Lines changed: 1 addition & 2 deletions
@@ -26,8 +26,7 @@
 def test_load_data(data_name, expected):
     dataset = load_data(
         data_name,
-        data_path=Path("data"),
-        download=True,
+        data_dir=Path("data"),
         transform=transforms.ToTensor(),
     )
     assert isinstance(dataset, expected)

tests/test_metrics.py

Lines changed: 27 additions & 9 deletions
@@ -134,18 +134,36 @@ def test_precision():


 def test_accuracy():
+    import numpy as np
     import torch

-    accuracy = Accuracy(num_classes=5)
-
-    y_true = torch.tensor([0, 3, 2, 3, 4])
-    y_pred = torch.tensor([0, 1, 2, 3, 4])
-
-    accuracy_score = accuracy(y_true, y_pred)
-
-    assert torch.abs(torch.tensor(accuracy_score - 0.8)) < 1e-5, (
-        f"Accuracy Score: {accuracy_score.item()}"
+    # Test the accuracy metric
+    y_true = torch.tensor([0, 1, 2, 3, 4, 5])
+    y_pred = torch.tensor([0, 1, 2, 3, 4, 5])
+    accuracy = Accuracy(num_classes=6, macro_averaging=False)
+    accuracy(y_true, y_pred)
+    assert accuracy.__returnmetric__() == 1.0, "Expected accuracy to be 1.0"
+    accuracy.__reset__()
+    assert accuracy.__returnmetric__() is np.nan, "Expected accuracy to be 0.0"
+    y_pred = torch.tensor([0, 1, 2, 3, 4, 4])
+    accuracy(y_true, y_pred)
+    assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, (
+        "Expected accuracy to be 0.8333333134651184"
+    )
+    accuracy.__reset__()
+    accuracy.macro_averaging = True
+    accuracy(y_true, y_pred)
+    y_true_1 = torch.tensor([0, 1, 2, 3, 4, 5])
+    y_pred_1 = torch.tensor([0, 1, 2, 3, 4, 4])
+    accuracy(y_true_1, y_pred_1)
+    assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, (
+        "Expected accuracy to be 0.8333333134651186"
+    )
+    accuracy.macro_averaging = False
+    assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, (
+        "Expected accuracy to be 0.8333333134651184"
     )
+    accuracy.__reset__()


 def test_entropypred():
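
A side note on why the micro- and macro-averaged assertions in this test target the same value: across the two accumulated batches every class occurs exactly twice, so the mean of per-class accuracies equals the pooled accuracy. A quick standalone check of that arithmetic (independent of the repository code):

import torch

# Concatenation of the two batches used in test_accuracy above.
y_true = torch.tensor([0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5])
y_pred = torch.tensor([0, 1, 2, 3, 4, 4, 0, 1, 2, 3, 4, 4])

# Micro-averaging: pool all samples.
micro = (y_true == y_pred).float().mean().item()  # 10/12 = 0.8333...

# Macro-averaging: mean of per-class accuracies over the six classes.
per_class = [(y_pred[y_true == c] == c).float().mean() for c in y_true.unique()]
macro = torch.stack(per_class).mean().item()  # (5 * 1.0 + 0.0) / 6 = 0.8333...

print(micro, macro)  # both ~0.8333 because the classes are balanced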
