Skip to content

Commit 53d641d

Browse files
kzkadcvfdev-5
andauthored
Fix device compatibility for PearsonCorrelation metric (#3223)
* add PearsonCorrelation metric * match the notation of the docstring with the other metrics * move PearsonCorrelation metric from contrib.metrics.regression to metrics.regression * update test for PearsonCorrelation metric * update test * modify doc for PearsonCorrelation metric * fix import * resolve code formatting issue * remove loop from test * Update ignite/metrics/regression/pearson_correlation.py Co-authored-by: vfdev <[email protected]> * Update pearson_correlation.py * update test for PearsonCorrelation * Update tests/ignite/metrics/regression/test_pearson_correlation.py * relax pytest.approx * fix device compatibility --------- Co-authored-by: vfdev <[email protected]>
1 parent 0dfeabe commit 53d641d

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

ignite/metrics/regression/pearson_correlation.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,11 @@ def reset(self) -> None:
8787

8888
def _update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
8989
y_pred, y = output[0].detach(), output[1].detach()
90-
self._sum_of_y_preds += y_pred.sum()
91-
self._sum_of_ys += y.sum()
92-
self._sum_of_y_pred_squares += y_pred.square().sum()
93-
self._sum_of_y_squares += y.square().sum()
94-
self._sum_of_products += (y_pred * y).sum()
90+
self._sum_of_y_preds += y_pred.sum().to(self._device)
91+
self._sum_of_ys += y.sum().to(self._device)
92+
self._sum_of_y_pred_squares += y_pred.square().sum().to(self._device)
93+
self._sum_of_y_squares += y.square().sum().to(self._device)
94+
self._sum_of_products += (y_pred * y).sum().to(self._device)
9595
self._num_examples += y.shape[0]
9696

9797
@sync_all_reduce(

0 commit comments

Comments
 (0)