Skip to content

Commit 69708c7

Browse files
committed
started fixing my stuff
1 parent 4174cd4 commit 69708c7

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

CollaborativeCoding/metrics/precision.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ def forward(self, y_true: torch.tensor, logits: torch.tensor) -> torch.tensor:
4040
if self.macro_averaging
4141
else self._micro_avg_precision(y_true, y_pred)
4242
)
43+
44+
def accumulate(self):
45+
pass # TODO fill
46+
47+
def reset(self):
48+
pass # TODO fill
4349

4450
def _micro_avg_precision(
4551
self, y_true: torch.tensor, y_pred: torch.tensor

tests/test_metrics.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,6 @@ def test_f1score():
7979

8080

8181
def test_precision():
82-
from random import randint
83-
8482
import numpy as np
8583
import torch
8684
from sklearn.metrics import precision_score
@@ -107,12 +105,12 @@ def test_precision():
107105
assert macro_precision_score.item() >= 0, "Expected non-negative value"
108106

109107
# find predictions
110-
y_pred = logits.argmax(dim=-1, keepdims=True)
108+
y_pred = logits.argmax(dim=-1)
111109

112110
# check dimension
113-
assert y_true.shape == torch.Size([N, 1]) or torch.Size([N])
111+
assert y_true.shape == torch.Size([N])
114112
assert logits.shape == torch.Size([N, C])
115-
assert y_pred.shape == torch.Size([N, 1]) or torch.Size([N])
113+
assert y_pred.shape == torch.Size([N])
116114

117115
# find true values with scikit learn
118116
scikit_macro_precision = precision_score(y_true, y_pred, average="macro")

0 commit comments

Comments
 (0)