Skip to content

Commit 867cb2a

Browse files
committed
ruffedisorted
1 parent 363d20f commit 867cb2a

File tree

5 files changed

+29
-12
lines changed

5 files changed

+29
-12
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ env2/*
1515
ruffian.sh
1616
localtest.sh
1717

18+
# Johanthings
19+
formatting.x
20+
1821
# Byte-compiled / optimized / DLL files
1922
__pycache__/
2023
*.py[cod]

CollaborativeCoding/dataloaders/mnist_4_9.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,14 @@ class MNISTDataset4_9(Dataset):
2020
Whether to train the model or not, by default False
2121
"""
2222

23-
def __init__(self, data_path: Path, sample_ids: np.ndarray, train: bool = False, transform = None, nr_channels: int = 1):
23+
def __init__(
24+
self,
25+
data_path: Path,
26+
sample_ids: np.ndarray,
27+
train: bool = False,
28+
transform=None,
29+
nr_channels: int = 1,
30+
):
2431
super.__init__()
2532
self.data_path = data_path
2633
self.mnist_path = self.data_path / "MNIST"
@@ -51,7 +58,7 @@ def __getitem__(self, idx):
5158
)
5259

5360
image = np.expand_dims(image, axis=0) # Channel
54-
61+
5562
if self.transform:
5663
image = self.transform(image)
5764

CollaborativeCoding/load_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def load_data(dataset: str, *args, **kwargs) -> tuple:
6868
case "mnist_4-9":
6969
dataset = MNISTDataset4_9
7070
train_labels, test_labels = downloader.mnist(data_dir=data_dir)
71-
labels = np.arange(4,10)
71+
labels = np.arange(4, 10)
7272
case _:
7373
raise NotImplementedError(f"Dataset: {dataset} not implemented.")
7474

CollaborativeCoding/metrics/precision.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1+
import numpy as np
12
import torch
23
import torch.nn as nn
3-
import numpy as np
44

55

66
class Precision(nn.Module):
@@ -42,7 +42,6 @@ def forward(self, y_true: torch.tensor, logits: torch.tensor) -> torch.tensor:
4242
# Append to the class-global values
4343
self.y_true.append(y_true)
4444
self.y_pred.append(y_pred)
45-
4645

4746
def _micro_avg_precision(
4847
self, y_true: torch.tensor, y_pred: torch.tensor
@@ -99,17 +98,21 @@ def _macro_avg_precision(
9998
fp = torch.sum(~true_oh.bool() * pred_oh, 0)
10099

101100
return torch.nanmean(tp / (tp + fp))
102-
101+
103102
def __returnmetric__(self):
104103
if self.y_true == [] and self.y_pred == []:
105104
return np.nan
106105
elif self.y_true == [] or self.y_pred == []:
107106
raise ValueError("y_true or y_pred is empty.")
108107
self.y_true = torch.cat(self.y_true)
109108
self.y_pred = torch.cat(self.y_pred)
110-
111-
return self._macro_avg_precision(self.y_true, self.y_pred) if self.macro_averaging else self._micro_avg_precision(self.y_true, self.y_pred)
112-
109+
110+
return (
111+
self._macro_avg_precision(self.y_true, self.y_pred)
112+
if self.macro_averaging
113+
else self._micro_avg_precision(self.y_true, self.y_pred)
114+
)
115+
113116
def __reset__(self):
114117
"""Resets the class-global lists of true and predicted values to empty lists.
115118

tests/test_metrics.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,13 @@ def test_precision():
9898
precision_micro = Precision(num_classes=C)
9999
precision_macro = Precision(num_classes=C, macro_averaging=True)
100100

101-
# find scores
102-
micro_precision_score = precision_micro(y_true, logits)
103-
macro_precision_score = precision_macro(y_true, logits)
101+
# run metric object
102+
precision_micro(y_true, logits)
103+
precision_macro(y_true, logits)
104+
105+
# get metric scores
106+
micro_precision_score = precision_micro.__returnmetric__()
107+
macro_precision_score = precision_macro.__returnmetric__()
104108

105109
# check output to be tensor
106110
assert isinstance(micro_precision_score, torch.Tensor), "Tensor output is expected."

0 commit comments

Comments
 (0)