Skip to content

Commit 51ba0a9

Browse files
authored
Fixes gpu tests failures in RocCurve by 2802 (#2867)
* Update roc_auc.py * Fixed issues
1 parent a9104a7 commit 51ba0a9

File tree

2 files changed

+14
-10
lines changed

2 files changed

+14
-10
lines changed

ignite/contrib/metrics/roc_auc.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ def roc_auc_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor) -> float:
1818
def roc_auc_curve_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor) -> Tuple[Any, Any, Any]:
1919
from sklearn.metrics import roc_curve
2020

21-
y_true = y_targets.numpy()
22-
y_pred = y_preds.numpy()
21+
y_true = y_targets.cpu().numpy()
22+
y_pred = y_preds.cpu().numpy()
2323
return roc_curve(y_true, y_pred)
2424

2525

@@ -181,9 +181,9 @@ def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
181181
if idist.get_rank() == 0:
182182
# Run compute_fn on zero rank only
183183
fpr, tpr, thresholds = self.compute_fn(_prediction_tensor, _target_tensor)
184-
fpr = torch.tensor(fpr)
185-
tpr = torch.tensor(tpr)
186-
thresholds = torch.tensor(thresholds)
184+
fpr = torch.tensor(fpr, device=_prediction_tensor.device)
185+
tpr = torch.tensor(tpr, device=_prediction_tensor.device)
186+
thresholds = torch.tensor(thresholds, device=_prediction_tensor.device)
187187
else:
188188
fpr, tpr, thresholds = None, None, None
189189

tests/ignite/contrib/metrics/test_roc_curve.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def update(engine, i):
149149

150150
engine = Engine(update)
151151

152-
device = "cpu" if idist.device().type == "xla" else idist.device()
152+
device = torch.device("cpu") if idist.device().type == "xla" else idist.device()
153153
metric = RocCurve(device=device)
154154
metric.attach(engine, "roc_curve")
155155

@@ -159,10 +159,14 @@ def update(engine, i):
159159

160160
fpr, tpr, thresholds = engine.state.metrics["roc_curve"]
161161

162+
assert isinstance(fpr, torch.Tensor) and fpr.device == device
163+
assert isinstance(tpr, torch.Tensor) and tpr.device == device
164+
assert isinstance(thresholds, torch.Tensor) and thresholds.device == device
165+
162166
y = idist.all_gather(y)
163167
y_pred = idist.all_gather(y_pred)
164-
sk_fpr, sk_tpr, sk_thresholds = roc_curve(y, y_pred)
168+
sk_fpr, sk_tpr, sk_thresholds = roc_curve(y.cpu().numpy(), y_pred.cpu().numpy())
165169

166-
assert np.array_equal(fpr, sk_fpr)
167-
assert np.array_equal(tpr, sk_tpr)
168-
np.testing.assert_array_almost_equal(thresholds, sk_thresholds)
170+
np.testing.assert_array_almost_equal(fpr.cpu().numpy(), sk_fpr)
171+
np.testing.assert_array_almost_equal(tpr.cpu().numpy(), sk_tpr)
172+
np.testing.assert_array_almost_equal(thresholds.cpu().numpy(), sk_thresholds)

0 commit comments

Comments
 (0)