Skip to content

Commit 87753d5

Browse files
committed
ruffed
1 parent 4e8c4e6 commit 87753d5

File tree

4 files changed

+32
-16
lines changed

4 files changed

+32
-16
lines changed

CollaborativeCoding/metrics/accuracy.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1+
import numpy as np
12
import torch
23
from torch import nn
3-
import numpy as np
44

55

66
class Accuracy(nn.Module):
@@ -78,22 +78,20 @@ def _micro_acc(self):
7878
Micro-average accuracy score.
7979
"""
8080
return (self.y_true == self.y_pred).float().mean().item()
81-
81+
8282
def __returnmetric__(self):
8383
if self.y_true == [] or self.y_pred == []:
8484
return np.nan
85-
if isinstance(self.y_true,list):
85+
if isinstance(self.y_true, list):
8686
if len(self.y_true) == 1:
8787
self.y_true = self.y_true[0]
8888
self.y_pred = self.y_pred[0]
8989
else:
9090
self.y_true = torch.cat(self.y_true)
9191
self.y_pred = torch.cat(self.y_pred)
9292
return self._micro_acc() if not self.macro_averaging else self._macro_acc()
93-
93+
9494
def __reset__(self):
9595
self.y_true = []
9696
self.y_pred = []
9797
return None
98-
99-

main.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import numpy as np
22
import torch as th
33
import torch.nn as nn
4-
import wandb
54
from torch.utils.data import DataLoader
65
from torchvision import transforms
76
from tqdm import tqdm
8-
#from wandb_api import WANDB_API
97

8+
import wandb
109
from CollaborativeCoding import (
1110
MetricWrapper,
1211
createfolders,
@@ -15,6 +14,9 @@
1514
load_model,
1615
)
1716

17+
# from wandb_api import WANDB_API
18+
19+
1820

1921
def main():
2022
"""
@@ -126,7 +128,7 @@ def main():
126128
print("Dry run completed successfully.")
127129
exit()
128130

129-
# wandb.login(key=WANDB_API)
131+
# wandb.login(key=WANDB_API)
130132
wandb.init(
131133
entity="ColabCode",
132134
project=args.run_name,

tests/test_metrics.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33
import pytest
44

55
from CollaborativeCoding.load_metric import MetricWrapper
6-
from CollaborativeCoding.metrics import Accuracy, F1Score, Precision, Recall, EntropyPrediction
6+
from CollaborativeCoding.metrics import (
7+
Accuracy,
8+
EntropyPrediction,
9+
F1Score,
10+
Precision,
11+
Recall,
12+
)
713

814

915
@pytest.mark.parametrize(
@@ -128,8 +134,8 @@ def test_precision():
128134

129135

130136
def test_accuracy():
131-
import torch
132137
import numpy as np
138+
import torch
133139

134140
# Test the accuracy metric
135141
y_true = torch.tensor([0, 1, 2, 3, 4, 5])
@@ -141,20 +147,25 @@ def test_accuracy():
141147
assert accuracy.__returnmetric__() is np.nan, "Expected accuracy to be 0.0"
142148
y_pred = torch.tensor([0, 1, 2, 3, 4, 4])
143149
accuracy(y_true, y_pred)
144-
assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, "Expected accuracy to be 0.8333333134651184"
150+
assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, (
151+
"Expected accuracy to be 0.8333333134651184"
152+
)
145153
accuracy.__reset__()
146154
accuracy.macro_averaging = True
147155
accuracy(y_true, y_pred)
148156
y_true_1 = torch.tensor([0, 1, 2, 3, 4, 5])
149157
y_pred_1 = torch.tensor([0, 1, 2, 3, 4, 4])
150158
accuracy(y_true_1, y_pred_1)
151-
assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, "Expected accuracy to be 0.8333333134651186"
159+
assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, (
160+
"Expected accuracy to be 0.8333333134651186"
161+
)
152162
accuracy.macro_averaging = False
153-
assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, "Expected accuracy to be 0.8333333134651184"
163+
assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, (
164+
"Expected accuracy to be 0.8333333134651184"
165+
)
154166
accuracy.__reset__()
155167

156168

157-
158169
def test_entropypred():
159170
import torch as th
160171

tests/test_models.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
import pytest
22
import torch
33

4-
from CollaborativeCoding.models import ChristianModel, JanModel, MagnusModel, SolveigModel
4+
from CollaborativeCoding.models import (
5+
ChristianModel,
6+
JanModel,
7+
MagnusModel,
8+
SolveigModel,
9+
)
510

611

712
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)