
Commit a1acf07

Merge commit a1acf07 (2 parents: f04c2f3 + b693919)
Commit message: "pulled from main"

6 files changed: +72 additions, -351 deletions

CollaborativeCoding/dataloaders/mnist_4_9.py

Lines changed: 5 additions & 6 deletions

@@ -43,12 +43,11 @@ def __init__(
         self.labels_path = self.mnist_path / (
             MNIST_SOURCE["train_labels"][1] if train else MNIST_SOURCE["test_labels"][1]
         )
-
-        # Functions to map the labels from (4,9) -> (0,5) for CrossEntropyLoss to work properly.
-        self.label_shift = lambda x: x-4
-        self.label_restore = lambda x: x+4
-
-
+
+        # Functions to map the labels from (4,9) -> (0,5) for CrossEntropyLoss to work properly.
+        self.label_shift = lambda x: x - 4
+        self.label_restore = lambda x: x + 4
+
 
     def __len__(self):
         return len(self.samples)
 
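Aside: the mapping above is easy to sanity-check in isolation. A minimal sketch, assuming plain integer labels (label_shift/label_restore are the names from the diff; the loop is illustrative only, not the repository's dataloader):

    # Shift raw labels 4/9 down so CrossEntropyLoss sees targets starting at 0.
    label_shift = lambda x: x - 4    # 4 -> 0, 9 -> 5
    label_restore = lambda x: x + 4  # inverse: 0 -> 4, 5 -> 9

    for raw in (4, 9):
        shifted = label_shift(raw)
        assert label_restore(shifted) == raw  # round trip is lossless
        print(raw, "->", shifted)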

main.py

Lines changed: 1 addition & 1 deletion

@@ -139,7 +139,7 @@ def main():
 
     for epoch in range(args.epoch):
         # Training loop start
-        print(f"Epoch: {epoch+1}/{args.epoch}")
+        print(f"Epoch: {epoch + 1}/{args.epoch}")
         trainingloss = []
         model.train()
         for x, y in tqdm(trainloader, desc="Training"):
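The hunk above is purely cosmetic: {epoch+1} and {epoch + 1} format to the same string, so only the source spacing changes (consistent with an auto-formatter pass, though the commit does not say which tool). A quick check with hypothetical values:

    epoch, total = 0, 10
    assert f"Epoch: {epoch+1}/{total}" == f"Epoch: {epoch + 1}/{total}"  # both yield "Epoch: 1/10"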

tests/test_dataloaders.py

Lines changed: 12 additions & 8 deletions

@@ -1,6 +1,5 @@
 from pathlib import Path
 
-import numpy as np
 import pytest
 import torch
 from torchvision import transforms
@@ -26,14 +25,19 @@
     ],
 )
 def test_load_data(data_name, expected):
-    print(data_name)
    dataset, _, _ = load_data(
        data_name,
-        data_dir=Path("data"),
+        train=False,
+        data_dir=Path("Data"),
        transform=transforms.ToTensor(),
    )
-    assert isinstance(dataset, expected)
-    assert len(dataset) > 0
-    assert isinstance(dataset[0], tuple)
-    assert isinstance(dataset[0][0], torch.Tensor)
-    assert isinstance(dataset[0][1], int)
+
+    sample = dataset[0]
+    img, label = sample
+
+    assert isinstance(dataset, expected), f"{type(dataset)} != {expected}"
+    assert len(dataset) > 0, "Dataset is empty"
+    assert isinstance(sample, tuple), f"{type(sample)} != tuple"
+    assert isinstance(img, torch.Tensor), f"{type(img)} != torch.Tensor"
+    assert isinstance(label, int), f"{type(label)} != int"
+    assert len(img.size()) == 3, f"{len(img.size())} != 3"
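The rewritten test makes the per-sample contract explicit: a (Tensor, int) tuple whose image has exactly three dimensions. A standalone sketch of a conforming sample, assuming a 28x28 single-channel MNIST-style image (the zero tensor is a stand-in, not repository data):

    import torch

    img = torch.zeros(1, 28, 28)  # (C, H, W): three dimensions, matching the new size assert
    label = 4                     # a plain Python int, as the test requires
    sample = (img, label)

    assert isinstance(sample, tuple)
    assert isinstance(sample[0], torch.Tensor) and sample[0].dim() == 3
    assert isinstance(sample[1], int)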

tests/test_metrics.py

Lines changed: 37 additions & 194 deletions

@@ -1,214 +1,57 @@
-from random import randint
+import itertools
 
 import pytest
 
 from CollaborativeCoding.load_metric import MetricWrapper
-from CollaborativeCoding.metrics import (
-    Accuracy,
-    EntropyPrediction,
-    F1Score,
-    Precision,
-    Recall,
-)
 
-
-@pytest.mark.parametrize(
-    "metric, num_classes, macro_averaging",
-    [
-        ("f1", randint(2, 10), False),
-        ("f1", randint(2, 10), True),
-        ("recall", randint(2, 10), False),
-        ("recall", randint(2, 10), True),
-        ("accuracy", randint(2, 10), False),
-        ("accuracy", randint(2, 10), True),
-        ("precision", randint(2, 10), False),
-        ("precision", randint(2, 10), True),
-        ("entropy", randint(2, 10), False),
-    ],
-)
-def test_metric_wrapper(metric, num_classes, macro_averaging):
-    import numpy as np
-    import torch
-
-    y_true = torch.arange(num_classes, dtype=torch.int64)
-    logits = torch.rand(num_classes, num_classes)
-
-    metrics = MetricWrapper(
-        metric,
-        num_classes=num_classes,
-        macro_averaging=macro_averaging,
-    )
-
-    metrics(y_true, logits)
-    score = metrics.getmetrics()
-    metrics.resetmetric()
-    empty_score = metrics.getmetrics()
-
-    assert isinstance(score, dict), "Expected a dictionary output."
-    assert metric in score, f"Expected {metric} metric in the output."
-    assert score[metric] >= 0, "Expected a non-negative value."
-    assert np.isnan(empty_score[metric]), "Expected an empty list."
-
-
-def test_recall():
-    import torch
-
-    y_true = torch.tensor([0, 1, 2, 3, 4, 5, 6])
-    logits = torch.randn(7, 7)
-
-    recall_micro = Recall(7)
-    recall_macro = Recall(7, macro_averaging=True)
-
-    recall_micro(y_true, logits)
-    recall_macro(y_true, logits)
-
-    recall_micro_score = recall_micro.__returnmetric__()
-    recall_macro_score = recall_macro.__returnmetric__()
-
-    assert isinstance(recall_micro_score, torch.Tensor), "Expected a tensor output."
-    assert isinstance(recall_macro_score, torch.Tensor), "Expected a tensor output."
-    assert recall_micro_score.item() >= 0, "Expected a non-negative value."
-    assert recall_macro_score.item() >= 0, "Expected a non-negative value."
+METRICS = ["f1", "recall", "accuracy", "precision", "entropy"]
 
 
-def test_f1score():
-    import torch
-
-    # Example case with known output
-    y_true = torch.tensor([0, 1, 2, 2, 1, 0])  # True labels
-    y_pred = torch.tensor([0, 1, 1, 2, 0, 0])  # Predicted labels
-
-    # Create F1Score object for micro and macro averaging
-    f1_micro = F1Score(num_classes=3, macro_averaging=False)
-    f1_macro = F1Score(num_classes=3, macro_averaging=True)
-
-    # Update F1 score with predictions
-    f1_micro(y_true, y_pred)
-    f1_macro(y_true, y_pred)
+def _metric_combinations():
+    """
+    Yield various combinations of metrics:
+    1. Single metric as a list
+    2. Pairs of metrics
+    3. All metrics
+    """
 
-    # Get F1 scores
-    micro_f1_score = f1_micro.__returnmetric__()
-    macro_f1_score = f1_macro.__returnmetric__()
-
-    # Check if outputs are tensors
-    assert isinstance(micro_f1_score, torch.Tensor), (
-        "Micro F1 score should be a tensor."
-    )
-    assert isinstance(macro_f1_score, torch.Tensor), (
-        "Macro F1 score should be a tensor."
-    )
+    # Single metrics as lists
+    for m in METRICS:
+        yield [m]
 
-    # Check that F1 scores are between 0 and 1
-    assert 0 <= micro_f1_score.item() <= 1, "Micro F1 score should be between 0 and 1."
-    assert 0 <= macro_f1_score.item() <= 1, "Macro F1 score should be between 0 and 1."
+    # Pairs of metrics (2-combinations)
+    for combo in itertools.combinations(METRICS, 2):
+        yield list(combo)
 
-    print(f"Micro F1 Score: {micro_f1_score.item()}")
-    print(f"Macro F1 Score: {macro_f1_score.item()}")
+    # Also test all metrics at once
+    yield METRICS
 
 
-def test_precision():
+@pytest.mark.parametrize("metrics", _metric_combinations())
+@pytest.mark.parametrize("num_classes", [2, 3, 5, 10])
+@pytest.mark.parametrize("macro_averaging", [True, False])
+def test_metric_wrapper(metrics, num_classes, macro_averaging):
     import numpy as np
     import torch
-    from sklearn.metrics import precision_score
 
-    C = randint(2, 10)  # number of classes
-    N = randint(2, 10 * C)  # batchsize
-    y_true = torch.randint(0, C, (N,))
-    logits = torch.randn(N, C)
-
-    # create metric objects
-    precision_micro = Precision(num_classes=C)
-    precision_macro = Precision(num_classes=C, macro_averaging=True)
-
-    # run metric object
-    precision_micro(y_true, logits)
-    precision_macro(y_true, logits)
-
-    # get metric scores
-    micro_precision_score = precision_micro.__returnmetric__()
-    macro_precision_score = precision_macro.__returnmetric__()
-
-    # check output to be tensor
-    assert isinstance(micro_precision_score, torch.Tensor), "Tensor output is expected."
-    assert isinstance(macro_precision_score, torch.Tensor), "Tensor output is expected."
-
-    # check for non-negativity
-    assert micro_precision_score.item() >= 0, "Expected non-negative value"
-    assert macro_precision_score.item() >= 0, "Expected non-negative value"
-
-    # find predictions
-    y_pred = logits.argmax(dim=-1)
-
-    # check dimension
-    assert y_true.shape == torch.Size([N])
-    assert logits.shape == torch.Size([N, C])
-    assert y_pred.shape == torch.Size([N])
-
-    # find true values with scikit learn
-    scikit_macro_precision = precision_score(y_true, y_pred, average="macro")
-    scikit_micro_precision = precision_score(y_true, y_pred, average="micro")
-
-    # check for similarity
-    assert np.isclose(scikit_micro_precision, micro_precision_score, atol=1e-5), (
-        "Score does not match scikit's score"
-    )
-    assert np.isclose(scikit_macro_precision, macro_precision_score, atol=1e-5), (
-        "Score does not match scikit's score"
-    )
-
-
-def test_accuracy():
-    import numpy as np
-    import torch
+    y_true = torch.arange(num_classes, dtype=torch.int64)
+    logits = torch.rand(num_classes, num_classes)
 
-    # Test the accuracy metric
-    y_true = torch.tensor([0, 1, 2, 3, 4, 5])
-    y_pred = torch.tensor([0, 1, 2, 3, 4, 5])
-    accuracy = Accuracy(num_classes=6, macro_averaging=False)
-    accuracy(y_true, y_pred)
-    assert accuracy.__returnmetric__() == 1.0, "Expected accuracy to be 1.0"
-    accuracy.__reset__()
-    assert accuracy.__returnmetric__() is np.nan, "Expected accuracy to be 0.0"
-    y_pred = torch.tensor([0, 1, 2, 3, 4, 4])
-    accuracy(y_true, y_pred)
-    assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, (
-        "Expected accuracy to be 0.8333333134651184"
-    )
-    accuracy.__reset__()
-    accuracy.macro_averaging = True
-    accuracy(y_true, y_pred)
-    y_true_1 = torch.tensor([0, 1, 2, 3, 4, 5])
-    y_pred_1 = torch.tensor([0, 1, 2, 3, 4, 4])
-    accuracy(y_true_1, y_pred_1)
-    assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, (
-        "Expected accuracy to be 0.8333333134651186"
-    )
-    accuracy.macro_averaging = False
-    assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, (
-        "Expected accuracy to be 0.8333333134651184"
+    mw = MetricWrapper(
+        *metrics,
+        num_classes=num_classes,
+        macro_averaging=macro_averaging,
     )
-    accuracy.__reset__()
 
+    mw(y_true, logits)
+    score = mw.getmetrics()
+    mw.resetmetric()
+    empty_score = mw.getmetrics()
 
-def test_entropypred():
-    import torch as th
-
-    true_lab = th.rand(6, 5)
-
-    metric = EntropyPrediction(num_classes=5)
-
-    # Test if the metric stores multiple values
-    pred_logits = th.rand(6, 5)
-    metric(true_lab, pred_logits)
-
-    pred_logits = th.rand(6, 5)
-    metric(true_lab, pred_logits)
-
-    pred_logits = th.rand(6, 5)
-    metric(true_lab, pred_logits)
-
-    assert type(metric.__returnmetric__()) == th.Tensor
+    assert isinstance(score, dict), "Expected a dictionary output."
+    for m in metrics:
+        assert m in score, f"Expected metric '{m}' in the output."
+        assert score[m] >= 0, "Expected a non-negative value."
 
-    # Test than an error is raised with num_class != class dimension length
-    with pytest.raises(AssertionError):
-        metric(true_lab, th.rand(6, 6))
+        assert m in empty_score, f"Expected metric '{m}' in the output."
+        assert np.isnan(empty_score[m]), "Expected an empty list."
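This rewrite collapses six hand-written tests into one parametrized grid. A standalone copy of the generator (same logic as the diff) shows how many parameter sets it yields:

    import itertools

    METRICS = ["f1", "recall", "accuracy", "precision", "entropy"]

    def _metric_combinations():
        for m in METRICS:                                # 5 single-metric lists
            yield [m]
        for combo in itertools.combinations(METRICS, 2): # C(5, 2) = 10 pairs
            yield list(combo)
        yield METRICS                                    # 1 list with all five

    print(len(list(_metric_combinations())))  # 16

Crossed with the four num_classes values and two macro_averaging settings, pytest collects 16 * 4 * 2 = 128 cases of test_metric_wrapper.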
