Skip to content

Commit 0ebaced

Browse files
committed
formatted to pass tests
1 parent 9f400f5 commit 0ebaced

File tree

5 files changed

+9
-9
lines changed

5 files changed

+9
-9
lines changed

tests/test_metrics.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
2-
from utils.metrics import F1Score, Precision, Recall, Accuracy
3-
1+
from utils.metrics import Accuracy, F1Score, Precision, Recall
42

53

64
def test_recall():
@@ -84,7 +82,8 @@ def test_for_zero_denominator():
8482
assert precision4.allclose(torch.tensor(0.0), atol=1e-5), (
8583
f"Precision Score: {precision4.item()}"
8684
)
87-
85+
86+
8887
def test_accuracy():
8988
import torch
9089

tests/test_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def test_christian_model(image_shape, num_classes):
2121
f"Softmax output should sum to 1, but got: {y.sum()}"
2222
)
2323

24+
2425
@pytest.mark.parametrize(
2526
"image_shape, num_classes",
2627
[((1, 28, 28), 4), ((3, 16, 16), 10)],
@@ -34,4 +35,3 @@ def test_jan_model(image_shape, num_classes):
3435
y = model(x)
3536

3637
assert y.shape == (n, num_classes), f"Shape: {y.shape}"
37-

utils/load_metric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
import torch.nn as nn
55

6-
from .metrics import EntropyPrediction, F1Score, Precision, Accuracy
6+
from .metrics import Accuracy, EntropyPrediction, F1Score, Precision
77

88

99
class MetricWrapper(nn.Module):

utils/metrics/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
__all__ = ["EntropyPrediction", "Recall", "F1Score", "Precision", "Accuracy"]
22

3+
from .accuracy import Accuracy
34
from .EntropyPred import EntropyPrediction
45
from .F1 import F1Score
56
from .precision import Precision
67
from .recall import Recall
7-
from .accuracy import Accuracy

utils/metrics/accuracy.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,11 @@ def forward(self, y_true, y_pred):
2323
Accuracy score.
2424
"""
2525
return (y_true == y_pred).float().mean().item()
26-
26+
27+
2728
if __name__ == "__main__":
2829
y_true = torch.tensor([0, 3, 2, 3, 4])
2930
y_pred = torch.tensor([0, 1, 2, 3, 4])
3031

3132
accuracy = Accuracy()
32-
print(accuracy(y_true, y_pred))
33+
print(accuracy(y_true, y_pred))

0 commit comments

Comments
 (0)