Skip to content

Commit f33b614

Browse files
Add distributed test to RocCurve metric (#2802)
* Add test
* Override compute method
* Use engine in test
* Update ignite/contrib/metrics/roc_auc.py

---------

Co-authored-by: vfdev <[email protected]>
1 parent ac3c11b commit f33b614

File tree

2 files changed

+93
-3
lines changed

2 files changed

+93
-3
lines changed

ignite/contrib/metrics/roc_auc.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
from typing import Any, Callable, Tuple, Union
1+
from typing import Any, Callable, cast, Tuple, Union
22

33
import torch
44

5+
from ignite import distributed as idist
6+
from ignite.exceptions import NotComputableError
57
from ignite.metrics import EpochMetric
68

79

@@ -103,6 +105,8 @@ class RocCurve(EpochMetric):
103105
<https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_curve.html#
104106
sklearn.metrics.roc_curve>`_ is run on the first batch of data to ensure there are
105107
no issues. User will be warned in case there are any issues computing the function.
108+
device: optional device specification for internal storage.
109+
106110
Note:
107111
RocCurve expects y to be comprised of 0's and 1's. y_pred must either be probability estimates or confidence
108112
values. To apply an activation to y_pred, use output_transform as shown below:
@@ -137,15 +141,56 @@ def sigmoid_output_transform(output):
137141
FPR [0.0, 0.333, 0.333, 1.0]
138142
TPR [0.0, 0.0, 1.0, 1.0]
139143
Thresholds [2.0, 1.0, 0.711, 0.047]
144+
145+
.. versionchanged:: 0.4.11
146+
added `device` argument
140147
"""
141148

142-
def __init__(
    self,
    output_transform: Callable = lambda x: x,
    check_compute_fn: bool = False,
    device: Union[str, torch.device] = torch.device("cpu"),
) -> None:
    # Construct the metric; all three arguments are forwarded unchanged to
    # the parent class together with ``roc_auc_curve_compute_fn``.
    #
    # Raises:
    #     ModuleNotFoundError: if ``sklearn.metrics.roc_curve`` cannot be imported.

    # The import is only an availability probe (hence the unused-import
    # waiver): fail at construction time rather than later inside compute.
    try:
        from sklearn.metrics import roc_curve  # noqa: F401
    except ImportError as err:
        # Chain the original failure so the underlying import problem stays
        # visible as the explicit cause in the traceback.
        raise ModuleNotFoundError("This contrib module requires scikit-learn to be installed.") from err

    super(RocCurve, self).__init__(
        roc_auc_curve_compute_fn,
        output_transform=output_transform,
        check_compute_fn=check_compute_fn,
        device=device,
    )
167+
168+
def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the ROC curve over everything accumulated so far.

    When more than one process is running, predictions and targets are
    gathered from all processes, ``compute_fn`` is evaluated on rank zero
    only, and the result is broadcast from rank zero to the others.

    Returns:
        The ``(fpr, tpr, thresholds)`` triple.

    Raises:
        NotComputableError: if no predictions or no targets were accumulated.
    """
    if not (len(self._predictions) >= 1 and len(self._targets) >= 1):
        raise NotComputableError("RocCurve must have at least one example before it can be computed.")

    all_predictions = torch.cat(self._predictions, dim=0)
    all_targets = torch.cat(self._targets, dim=0)

    is_distributed = idist.get_world_size() > 1
    if is_distributed:
        # Collect the per-process tensors from every process.
        all_predictions = cast(torch.Tensor, idist.all_gather(all_predictions))
        all_targets = cast(torch.Tensor, idist.all_gather(all_targets))

    if idist.get_rank() == 0:
        # Only rank zero evaluates compute_fn; its outputs become tensors.
        raw_fpr, raw_tpr, raw_thresholds = self.compute_fn(all_predictions, all_targets)
        fpr = torch.tensor(raw_fpr)
        tpr = torch.tensor(raw_tpr)
        thresholds = torch.tensor(raw_thresholds)
    else:
        # Placeholders on the other ranks; replaced by the broadcast below.
        fpr = tpr = thresholds = None

    if is_distributed:
        # Hand the rank-zero result to every process, one tensor at a time
        # and in a fixed order so all ranks issue matching collective calls.
        fpr, tpr, thresholds = (
            idist.broadcast(item, src=0, safe_mode=True) for item in (fpr, tpr, thresholds)
        )

    return fpr, tpr, thresholds

tests/ignite/contrib/metrics/test_roc_curve.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,22 @@
66
import torch
77
from sklearn.metrics import roc_curve
88

9+
from ignite import distributed as idist
910
from ignite.contrib.metrics.roc_auc import RocCurve
1011
from ignite.engine import Engine
12+
from ignite.exceptions import NotComputableError
1113
from ignite.metrics.epoch_metric import EpochMetricWarning
1214

1315

16+
def test_wrong_setup():
    """Calling ``compute`` before any update must raise ``NotComputableError``."""
    # Build the metric outside the ``raises`` block so that only ``compute()``
    # can satisfy the expectation; an error in the constructor now fails the
    # test instead of being caught by the context manager. The default
    # constructor arguments are sufficient: no data is ever fed to the metric.
    metric = RocCurve()

    with pytest.raises(NotComputableError, match="RocCurve must have at least one example before it can be computed"):
        metric.compute()
23+
24+
1425
@pytest.fixture()
1526
def mock_no_sklearn():
1627
with patch.dict("sys.modules", {"sklearn.metrics": None}):
@@ -121,3 +132,37 @@ def test_check_compute_fn():
121132

122133
em = RocCurve(check_compute_fn=False)
123134
em.update(output)
135+
136+
137+
def test_distrib_integration(distributed):
    """Check RocCurve attached to an Engine against sklearn's ``roc_curve``.

    Each process seeds its RNG with a rank-dependent value, feeds its own
    random data through an engine, and the metric output is compared with
    ``roc_curve`` evaluated on the data gathered from all processes.
    """
    rank = idist.get_rank()
    torch.manual_seed(41 + rank)

    n_batches = 5
    batch_size = 10
    n_samples = n_batches * batch_size
    # Targets are drawn before predictions; the order fixes the random stream.
    y = torch.randint(0, 2, size=(n_samples,))
    y_pred = torch.rand((n_samples,))

    def update(engine, i):
        # The i-th batch is the i-th contiguous window of ``batch_size`` items.
        window = slice(i * batch_size, (i + 1) * batch_size)
        return y_pred[window], y[window]

    engine = Engine(update)

    # Internal storage stays on CPU when the distributed device is XLA.
    if idist.device().type == "xla":
        device = "cpu"
    else:
        device = idist.device()
    RocCurve(device=device).attach(engine, "roc_curve")

    engine.run(data=list(range(n_batches)), max_epochs=1)
    fpr, tpr, thresholds = engine.state.metrics["roc_curve"]

    # Reference values come from the data of all processes (targets gathered first).
    gathered_y = idist.all_gather(y)
    gathered_y_pred = idist.all_gather(y_pred)
    sk_fpr, sk_tpr, sk_thresholds = roc_curve(gathered_y, gathered_y_pred)

    assert np.array_equal(fpr, sk_fpr)
    assert np.array_equal(tpr, sk_tpr)
    np.testing.assert_array_almost_equal(thresholds, sk_thresholds)

0 commit comments

Comments
 (0)