1 | | -from random import randint |
| 1 | +import itertools |
2 | 2 |
|
3 | 3 | import pytest |
4 | 4 |
|
5 | 5 | from CollaborativeCoding.load_metric import MetricWrapper |
6 | | -from CollaborativeCoding.metrics import ( |
7 | | - Accuracy, |
8 | | - EntropyPrediction, |
9 | | - F1Score, |
10 | | - Precision, |
11 | | - Recall, |
12 | | -) |
13 | 6 |
|
14 | | - |
15 | | -@pytest.mark.parametrize( |
16 | | - "metric, num_classes, macro_averaging", |
17 | | - [ |
18 | | - ("f1", randint(2, 10), False), |
19 | | - ("f1", randint(2, 10), True), |
20 | | - ("recall", randint(2, 10), False), |
21 | | - ("recall", randint(2, 10), True), |
22 | | - ("accuracy", randint(2, 10), False), |
23 | | - ("accuracy", randint(2, 10), True), |
24 | | - ("precision", randint(2, 10), False), |
25 | | - ("precision", randint(2, 10), True), |
26 | | - ("entropy", randint(2, 10), False), |
27 | | - ], |
28 | | -) |
29 | | -def test_metric_wrapper(metric, num_classes, macro_averaging): |
30 | | - import numpy as np |
31 | | - import torch |
32 | | - |
33 | | - y_true = torch.arange(num_classes, dtype=torch.int64) |
34 | | - logits = torch.rand(num_classes, num_classes) |
35 | | - |
36 | | - metrics = MetricWrapper( |
37 | | - metric, |
38 | | - num_classes=num_classes, |
39 | | - macro_averaging=macro_averaging, |
40 | | - ) |
41 | | - |
42 | | - metrics(y_true, logits) |
43 | | - score = metrics.getmetrics() |
44 | | - metrics.resetmetric() |
45 | | - empty_score = metrics.getmetrics() |
46 | | - |
47 | | - assert isinstance(score, dict), "Expected a dictionary output." |
48 | | - assert metric in score, f"Expected {metric} metric in the output." |
49 | | - assert score[metric] >= 0, "Expected a non-negative value." |
50 | | - assert np.isnan(empty_score[metric]), "Expected an empty list." |
51 | | - |
52 | | - |
53 | | -def test_recall(): |
54 | | - import torch |
55 | | - |
56 | | - y_true = torch.tensor([0, 1, 2, 3, 4, 5, 6]) |
57 | | - logits = torch.randn(7, 7) |
58 | | - |
59 | | - recall_micro = Recall(7) |
60 | | - recall_macro = Recall(7, macro_averaging=True) |
61 | | - |
62 | | - recall_micro(y_true, logits) |
63 | | - recall_macro(y_true, logits) |
64 | | - |
65 | | - recall_micro_score = recall_micro.__returnmetric__() |
66 | | - recall_macro_score = recall_macro.__returnmetric__() |
67 | | - |
68 | | - assert isinstance(recall_micro_score, torch.Tensor), "Expected a tensor output." |
69 | | - assert isinstance(recall_macro_score, torch.Tensor), "Expected a tensor output." |
70 | | - assert recall_micro_score.item() >= 0, "Expected a non-negative value." |
71 | | - assert recall_macro_score.item() >= 0, "Expected a non-negative value." |
| 7 | +METRICS = ["f1", "recall", "accuracy", "precision", "entropy"] |
72 | 8 |
|
73 | 9 |
|
74 | | -def test_f1score(): |
75 | | - import torch |
76 | | - |
77 | | - # Example case with known output |
78 | | - y_true = torch.tensor([0, 1, 2, 2, 1, 0]) # True labels |
79 | | - y_pred = torch.tensor([0, 1, 1, 2, 0, 0]) # Predicted labels |
80 | | - |
81 | | - # Create F1Score object for micro and macro averaging |
82 | | - f1_micro = F1Score(num_classes=3, macro_averaging=False) |
83 | | - f1_macro = F1Score(num_classes=3, macro_averaging=True) |
84 | | - |
85 | | - # Update F1 score with predictions |
86 | | - f1_micro(y_true, y_pred) |
87 | | - f1_macro(y_true, y_pred) |
| 10 | +def _metric_combinations(): |
| 11 | + """ |
| 12 | + Yield various combinations of metrics: |
| 13 | + 1. Single metric as a list |
| 14 | + 2. Pairs of metrics |
| 15 | + 3. All metrics |
| 16 | + """ |
88 | 17 |
|
89 | | - # Get F1 scores |
90 | | - micro_f1_score = f1_micro.__returnmetric__() |
91 | | - macro_f1_score = f1_macro.__returnmetric__() |
92 | | - |
93 | | - # Check if outputs are tensors |
94 | | - assert isinstance(micro_f1_score, torch.Tensor), ( |
95 | | - "Micro F1 score should be a tensor." |
96 | | - ) |
97 | | - assert isinstance(macro_f1_score, torch.Tensor), ( |
98 | | - "Macro F1 score should be a tensor." |
99 | | - ) |
| 18 | + # Single metrics as lists |
| 19 | + for m in METRICS: |
| 20 | + yield [m] |
100 | 21 |
|
101 | | - # Check that F1 scores are between 0 and 1 |
102 | | - assert 0 <= micro_f1_score.item() <= 1, "Micro F1 score should be between 0 and 1." |
103 | | - assert 0 <= macro_f1_score.item() <= 1, "Macro F1 score should be between 0 and 1." |
| 22 | + # Pairs of metrics (2-combinations) |
| 23 | + for combo in itertools.combinations(METRICS, 2): |
| 24 | + yield list(combo) |
104 | 25 |
|
105 | | - print(f"Micro F1 Score: {micro_f1_score.item()}") |
106 | | - print(f"Macro F1 Score: {macro_f1_score.item()}") |
| 26 | + # Also test all metrics at once |
| 27 | + yield METRICS |
107 | 28 |
|
108 | 29 |
|
109 | | -def test_precision(): |
| 30 | +@pytest.mark.parametrize("metrics", _metric_combinations()) |
| 31 | +@pytest.mark.parametrize("num_classes", [2, 3, 5, 10]) |
| 32 | +@pytest.mark.parametrize("macro_averaging", [True, False]) |
| 33 | +def test_metric_wrapper(metrics, num_classes, macro_averaging): |
110 | 34 | import numpy as np |
111 | 35 | import torch |
112 | | - from sklearn.metrics import precision_score |
113 | 36 |
|
114 | | - C = randint(2, 10) # number of classes |
115 | | - N = randint(2, 10 * C) # batchsize |
116 | | - y_true = torch.randint(0, C, (N,)) |
117 | | - logits = torch.randn(N, C) |
118 | | - |
119 | | - # create metric objects |
120 | | - precision_micro = Precision(num_classes=C) |
121 | | - precision_macro = Precision(num_classes=C, macro_averaging=True) |
122 | | - |
123 | | - # run metric object |
124 | | - precision_micro(y_true, logits) |
125 | | - precision_macro(y_true, logits) |
126 | | - |
127 | | - # get metric scores |
128 | | - micro_precision_score = precision_micro.__returnmetric__() |
129 | | - macro_precision_score = precision_macro.__returnmetric__() |
130 | | - |
131 | | - # check output to be tensor |
132 | | - assert isinstance(micro_precision_score, torch.Tensor), "Tensor output is expected." |
133 | | - assert isinstance(macro_precision_score, torch.Tensor), "Tensor output is expected." |
134 | | - |
135 | | - # check for non-negativity |
136 | | - assert micro_precision_score.item() >= 0, "Expected non-negative value" |
137 | | - assert macro_precision_score.item() >= 0, "Expected non-negative value" |
138 | | - |
139 | | - # find predictions |
140 | | - y_pred = logits.argmax(dim=-1) |
141 | | - |
142 | | - # check dimension |
143 | | - assert y_true.shape == torch.Size([N]) |
144 | | - assert logits.shape == torch.Size([N, C]) |
145 | | - assert y_pred.shape == torch.Size([N]) |
146 | | - |
147 | | - # find true values with scikit learn |
148 | | - scikit_macro_precision = precision_score(y_true, y_pred, average="macro") |
149 | | - scikit_micro_precision = precision_score(y_true, y_pred, average="micro") |
150 | | - |
151 | | - # check for similarity |
152 | | - assert np.isclose(scikit_micro_precision, micro_precision_score, atol=1e-5), ( |
153 | | - "Score does not match scikit's score" |
154 | | - ) |
155 | | - assert np.isclose(scikit_macro_precision, macro_precision_score, atol=1e-5), ( |
156 | | - "Score does not match scikit's score" |
157 | | - ) |
158 | | - |
159 | | - |
160 | | -def test_accuracy(): |
161 | | - import numpy as np |
162 | | - import torch |
| 37 | + y_true = torch.arange(num_classes, dtype=torch.int64) |
| 38 | + logits = torch.rand(num_classes, num_classes) |
163 | 39 |
|
164 | | - # Test the accuracy metric |
165 | | - y_true = torch.tensor([0, 1, 2, 3, 4, 5]) |
166 | | - y_pred = torch.tensor([0, 1, 2, 3, 4, 5]) |
167 | | - accuracy = Accuracy(num_classes=6, macro_averaging=False) |
168 | | - accuracy(y_true, y_pred) |
169 | | - assert accuracy.__returnmetric__() == 1.0, "Expected accuracy to be 1.0" |
170 | | - accuracy.__reset__() |
171 | | - assert accuracy.__returnmetric__() is np.nan, "Expected accuracy to be 0.0" |
172 | | - y_pred = torch.tensor([0, 1, 2, 3, 4, 4]) |
173 | | - accuracy(y_true, y_pred) |
174 | | - assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, ( |
175 | | - "Expected accuracy to be 0.8333333134651184" |
176 | | - ) |
177 | | - accuracy.__reset__() |
178 | | - accuracy.macro_averaging = True |
179 | | - accuracy(y_true, y_pred) |
180 | | - y_true_1 = torch.tensor([0, 1, 2, 3, 4, 5]) |
181 | | - y_pred_1 = torch.tensor([0, 1, 2, 3, 4, 4]) |
182 | | - accuracy(y_true_1, y_pred_1) |
183 | | - assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, ( |
184 | | - "Expected accuracy to be 0.8333333134651186" |
185 | | - ) |
186 | | - accuracy.macro_averaging = False |
187 | | - assert np.abs(accuracy.__returnmetric__() - 0.8333333134651184) < 1e-5, ( |
188 | | - "Expected accuracy to be 0.8333333134651184" |
| 40 | + mw = MetricWrapper( |
| 41 | + *metrics, |
| 42 | + num_classes=num_classes, |
| 43 | + macro_averaging=macro_averaging, |
189 | 44 | ) |
190 | | - accuracy.__reset__() |
191 | 45 |
|
| 46 | + mw(y_true, logits) |
| 47 | + score = mw.getmetrics() |
| 48 | + mw.resetmetric() |
| 49 | + empty_score = mw.getmetrics() |
192 | 50 |
|
193 | | -def test_entropypred(): |
194 | | - import torch as th |
195 | | - |
196 | | - true_lab = th.rand(6, 5) |
197 | | - |
198 | | - metric = EntropyPrediction(num_classes=5) |
199 | | - |
200 | | - # Test if the metric stores multiple values |
201 | | - pred_logits = th.rand(6, 5) |
202 | | - metric(true_lab, pred_logits) |
203 | | - |
204 | | - pred_logits = th.rand(6, 5) |
205 | | - metric(true_lab, pred_logits) |
206 | | - |
207 | | - pred_logits = th.rand(6, 5) |
208 | | - metric(true_lab, pred_logits) |
209 | | - |
210 | | - assert type(metric.__returnmetric__()) == th.Tensor |
| 51 | + assert isinstance(score, dict), "Expected a dictionary output." |
| 52 | + for m in metrics: |
| 53 | + assert m in score, f"Expected metric '{m}' in the output." |
| 54 | + assert score[m] >= 0, "Expected a non-negative value." |
211 | 55 |
|
212 | | - # Test than an error is raised with num_class != class dimension length |
213 | | - with pytest.raises(AssertionError): |
214 | | - metric(true_lab, th.rand(6, 6)) |
| 56 | + assert m in empty_score, f"Expected metric '{m}' in the output after reset." |
| 57 | + assert np.isnan(empty_score[m]), "Expected NaN after reset." |
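For context, a minimal sketch of how the three stacked `pytest.mark.parametrize` decorators expand `test_metric_wrapper` into individual cases. `METRICS` and `_metric_combinations` are copied from the new test module above; the case count is worked out here for illustration and is not stated in the diff itself.

```python
# Sketch (not part of the diff): expansion of the parametrized test matrix.
import itertools

METRICS = ["f1", "recall", "accuracy", "precision", "entropy"]

def _metric_combinations():
    for m in METRICS:                                  # 5 single-metric lists
        yield [m]
    for combo in itertools.combinations(METRICS, 2):   # 10 metric pairs
        yield list(combo)
    yield METRICS                                      # 1 list with all metrics

n_metric_sets = sum(1 for _ in _metric_combinations())  # 5 + 10 + 1 = 16
n_cases = n_metric_sets * 4 * 2  # crossed with 4 num_classes values and 2 macro_averaging values
print(n_metric_sets, n_cases)    # 16 128
```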