|
1 | | -from typing import Any, Callable, Tuple, Union |
| 1 | +from typing import Any, Callable, cast, Tuple, Union |
2 | 2 |
|
3 | 3 | import torch |
4 | 4 |
|
| 5 | +from ignite import distributed as idist |
| 6 | +from ignite.exceptions import NotComputableError |
5 | 7 | from ignite.metrics import EpochMetric |
6 | 8 |
|
7 | 9 |
|
@@ -103,6 +105,8 @@ class RocCurve(EpochMetric): |
103 | 105 | <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_curve.html# |
104 | 106 | sklearn.metrics.roc_curve>`_ is run on the first batch of data to ensure there are |
105 | 107 | no issues. User will be warned in case there are any issues computing the function. |
| 108 | + device: optional device specification for internal storage. |
| 109 | +
|
106 | 110 | Note: |
107 | 111 | RocCurve expects y to be comprised of 0's and 1's. y_pred must either be probability estimates or confidence |
108 | 112 | values. To apply an activation to y_pred, use output_transform as shown below: |
@@ -137,15 +141,56 @@ def sigmoid_output_transform(output): |
137 | 141 | FPR [0.0, 0.333, 0.333, 1.0] |
138 | 142 | TPR [0.0, 0.0, 1.0, 1.0] |
139 | 143 | Thresholds [2.0, 1.0, 0.711, 0.047] |
| 144 | +
|
| 145 | + .. versionchanged:: 0.4.11 |
| 146 | + added `device` argument |
140 | 147 | """ |
141 | 148 |
|
142 | | - def __init__(self, output_transform: Callable = lambda x: x, check_compute_fn: bool = False) -> None: |
| 149 | + def __init__( |
| 150 | + self, |
| 151 | + output_transform: Callable = lambda x: x, |
| 152 | + check_compute_fn: bool = False, |
| 153 | + device: Union[str, torch.device] = torch.device("cpu"), |
| 154 | + ) -> None: |
143 | 155 |
|
144 | 156 | try: |
145 | 157 | from sklearn.metrics import roc_curve # noqa: F401 |
146 | 158 | except ImportError: |
147 | 159 | raise ModuleNotFoundError("This contrib module requires scikit-learn to be installed.") |
148 | 160 |
|
149 | 161 | super(RocCurve, self).__init__( |
150 | | - roc_auc_curve_compute_fn, output_transform=output_transform, check_compute_fn=check_compute_fn |
| 162 | + roc_auc_curve_compute_fn, |
| 163 | + output_transform=output_transform, |
| 164 | + check_compute_fn=check_compute_fn, |
| 165 | + device=device, |
151 | 166 | ) |
| 167 | + |
| 168 | + def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| 169 | + if len(self._predictions) < 1 or len(self._targets) < 1: |
| 170 | + raise NotComputableError("RocCurve must have at least one example before it can be computed.") |
| 171 | + |
| 172 | + _prediction_tensor = torch.cat(self._predictions, dim=0) |
| 173 | + _target_tensor = torch.cat(self._targets, dim=0) |
| 174 | + |
| 175 | + ws = idist.get_world_size() |
| 176 | + if ws > 1: |
| 177 | + # All gather across all processes |
| 178 | + _prediction_tensor = cast(torch.Tensor, idist.all_gather(_prediction_tensor)) |
| 179 | + _target_tensor = cast(torch.Tensor, idist.all_gather(_target_tensor)) |
| 180 | + |
| 181 | + if idist.get_rank() == 0: |
| 182 | + # Run compute_fn on zero rank only |
| 183 | + fpr, tpr, thresholds = self.compute_fn(_prediction_tensor, _target_tensor) |
| 184 | + fpr = torch.tensor(fpr) |
| 185 | + tpr = torch.tensor(tpr) |
| 186 | + thresholds = torch.tensor(thresholds) |
| 187 | + else: |
| 188 | + fpr, tpr, thresholds = None, None, None |
| 189 | + |
| 190 | + if ws > 1: |
| 191 | + # broadcast result to all processes |
| 192 | + fpr = idist.broadcast(fpr, src=0, safe_mode=True) |
| 193 | + tpr = idist.broadcast(tpr, src=0, safe_mode=True) |
| 194 | + thresholds = idist.broadcast(thresholds, src=0, safe_mode=True) |
| 195 | + |
| 196 | + return fpr, tpr, thresholds |
0 commit comments