Skip to content

Commit 3cb36b6

Browse files
authored
Merge pull request #80 from SFI-Visual-Intelligence/johan/devbranch
Updated precision metric to comply with new metric wrapper
2 parents 1f740ce + 867cb2a commit 3cb36b6

File tree

6 files changed

+65
-21
lines changed

6 files changed

+65
-21
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ env2/*
1818
ruffian.sh
1919
localtest.sh
2020

21+
# Johanthings
22+
formatting.x
23+
2124
# Byte-compiled / optimized / DLL files
2225
__pycache__/
2326
*.py[cod]

CollaborativeCoding/dataloaders/mnist_4_9.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,14 @@ class MNISTDataset4_9(Dataset):
2020
Whether to train the model or not, by default False
2121
"""
2222

23-
def __init__(self, data_path: Path, sample_ids: np.ndarray, train: bool = False):
23+
def __init__(
24+
self,
25+
data_path: Path,
26+
sample_ids: np.ndarray,
27+
train: bool = False,
28+
transform=None,
29+
nr_channels: int = 1,
30+
):
2431
super.__init__()
2532
self.data_path = data_path
2633
self.mnist_path = self.data_path / "MNIST"
@@ -52,4 +59,7 @@ def __getitem__(self, idx):
5259

5360
image = np.expand_dims(image, axis=0) # Channel
5461

62+
if self.transform:
63+
image = self.transform(image)
64+
5565
return image, label

CollaborativeCoding/load_data.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .dataloaders import (
55
Downloader,
66
MNISTDataset0_3,
7+
MNISTDataset4_9,
78
SVHNDataset,
89
USPSDataset0_6,
910
USPSH5_Digit_7_9_Dataset,
@@ -65,7 +66,9 @@ def load_data(dataset: str, *args, **kwargs) -> tuple:
6566
train_labels, test_labels = downloader.svhn(data_dir=data_dir)
6667
labels = np.arange(10)
6768
case "mnist_4-9":
68-
raise NotImplementedError("MNIST 4-9 dataset not yet implemented.")
69+
dataset = MNISTDataset4_9
70+
train_labels, test_labels = downloader.mnist(data_dir=data_dir)
71+
labels = np.arange(4, 10)
6972
case _:
7073
raise NotImplementedError(f"Dataset: {dataset} not implemented.")
7174

CollaborativeCoding/load_metric.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ class MetricWrapper(nn.Module):
2525
-------
2626
__call__(y_true, y_pred)
2727
Computes the specified metrics on the provided true and predicted labels.
28-
__getmetrics__(str_prefix: str = None)
28+
getmetrics(str_prefix: str = None)
2929
Retrieves the computed metrics, optionally prefixed with a string.
30-
reset()
30+
resetmetric()
3131
Resets the state of all metric computations.
3232
Examples
3333
--------
@@ -36,10 +36,10 @@ class MetricWrapper(nn.Module):
3636
>>> y_true = [0, 1, 0, 1]
3737
>>> y_pred = [0, 1, 1, 0]
3838
>>> metrics(y_true, y_pred)
39-
>>> metrics.__getmetrics__()
39+
>>> metrics.getmetrics()
4040
{'entropy': 0.6931471805599453, 'f1': 0.5, 'precision': 0.5}
41-
>>> metrics.reset()
42-
>>> metrics.__getmetrics__()
41+
>>> metrics.resetmetric()
42+
>>> metrics.getmetrics()
4343
{'entropy': [], 'f1': [], 'precision': []}
4444
"""
4545

CollaborativeCoding/metrics/precision.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
import torch
23
import torch.nn as nn
34

@@ -18,6 +19,8 @@ def __init__(self, num_classes: int, macro_averaging: bool = False):
1819

1920
self.num_classes = num_classes
2021
self.macro_averaging = macro_averaging
22+
self.y_true = []
23+
self.y_pred = []
2124

2225
def forward(self, y_true: torch.tensor, logits: torch.tensor) -> torch.tensor:
2326
"""Compute precision of model
@@ -35,11 +38,10 @@ def forward(self, y_true: torch.tensor, logits: torch.tensor) -> torch.tensor:
3538
Precision score
3639
"""
3740
y_pred = logits.argmax(dim=-1)
38-
return (
39-
self._macro_avg_precision(y_true, y_pred)
40-
if self.macro_averaging
41-
else self._micro_avg_precision(y_true, y_pred)
42-
)
41+
42+
# Append to the class-global values
43+
self.y_true.append(y_true)
44+
self.y_pred.append(y_pred)
4345

4446
def _micro_avg_precision(
4547
self, y_true: torch.tensor, y_pred: torch.tensor
@@ -58,7 +60,6 @@ def _micro_avg_precision(
5860
torch.tensor
5961
Micro-averaged precision
6062
"""
61-
print(y_true.shape)
6263
true_oh = torch.zeros(y_true.size(0), self.num_classes).scatter_(
6364
1, y_true.unsqueeze(1), 1
6465
)
@@ -98,6 +99,31 @@ def _macro_avg_precision(
9899

99100
return torch.nanmean(tp / (tp + fp))
100101

102+
def __returnmetric__(self):
103+
if self.y_true == [] and self.y_pred == []:
104+
return np.nan
105+
elif self.y_true == [] or self.y_pred == []:
106+
raise ValueError("y_true or y_pred is empty.")
107+
self.y_true = torch.cat(self.y_true)
108+
self.y_pred = torch.cat(self.y_pred)
109+
110+
return (
111+
self._macro_avg_precision(self.y_true, self.y_pred)
112+
if self.macro_averaging
113+
else self._micro_avg_precision(self.y_true, self.y_pred)
114+
)
115+
116+
def __reset__(self):
117+
"""Resets the class-global lists of true and predicted values to empty lists.
118+
119+
Returns
120+
-------
121+
None
122+
Returns None
123+
"""
124+
self.y_true = self.y_pred = []
125+
return None
126+
101127

102128
if __name__ == "__main__":
103129
print(

tests/test_metrics.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,6 @@ def test_f1score():
8585

8686

8787
def test_precision():
88-
from random import randint
89-
9088
import numpy as np
9189
import torch
9290
from sklearn.metrics import precision_score
@@ -100,9 +98,13 @@ def test_precision():
10098
precision_micro = Precision(num_classes=C)
10199
precision_macro = Precision(num_classes=C, macro_averaging=True)
102100

103-
# find scores
104-
micro_precision_score = precision_micro(y_true, logits)
105-
macro_precision_score = precision_macro(y_true, logits)
101+
# run metric object
102+
precision_micro(y_true, logits)
103+
precision_macro(y_true, logits)
104+
105+
# get metric scores
106+
micro_precision_score = precision_micro.__returnmetric__()
107+
macro_precision_score = precision_macro.__returnmetric__()
106108

107109
# check output to be tensor
108110
assert isinstance(micro_precision_score, torch.Tensor), "Tensor output is expected."
@@ -113,12 +115,12 @@ def test_precision():
113115
assert macro_precision_score.item() >= 0, "Expected non-negative value"
114116

115117
# find predictions
116-
y_pred = logits.argmax(dim=-1, keepdims=True)
118+
y_pred = logits.argmax(dim=-1)
117119

118120
# check dimension
119-
assert y_true.shape == torch.Size([N, 1]) or torch.Size([N])
121+
assert y_true.shape == torch.Size([N])
120122
assert logits.shape == torch.Size([N, C])
121-
assert y_pred.shape == torch.Size([N, 1]) or torch.Size([N])
123+
assert y_pred.shape == torch.Size([N])
122124

123125
# find true values with scikit learn
124126
scikit_macro_precision = precision_score(y_true, y_pred, average="macro")

0 commit comments

Comments
 (0)