
Commit 5d8d6bf

Refactor EpochMetric and make it idempotent (#2800)
* Add failing assertion
* Fix compute method
* Fix bugs and Mypy errors
* Fix remaining mypy errors
* Fix some bugs

Co-authored-by: vfdev <[email protected]>
1 parent c7c0df0 commit 5d8d6bf

File tree

5 files changed: +65 −110 lines changed

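In practice, "idempotent" here means that EpochMetric.compute() can be called repeatedly without redoing (or corrupting) the gather/compute/broadcast work: the first call stores its value in self._result and later calls return that cache until reset() clears it. A hedged usage sketch of the new behaviour (accuracy_fn below is a made-up compute_fn, not part of this commit):

import torch

from ignite.metrics import EpochMetric


def accuracy_fn(y_preds: torch.Tensor, y_targets: torch.Tensor) -> float:
    # any Callable[[torch.Tensor, torch.Tensor], float] works as a compute_fn
    return (y_preds.round() == y_targets).float().mean().item()


metric = EpochMetric(accuracy_fn)
metric.update((torch.tensor([0.1, 0.9, 0.8]), torch.tensor([0, 1, 1])))

first = metric.compute()
second = metric.compute()  # served from the cached self._result, nothing is recomputed
assert first == second

metric.reset()  # clears both the accumulated data and the cached result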

ignite/contrib/metrics/precision_recall_curve.py

Lines changed: 32 additions & 30 deletions
@@ -78,41 +78,43 @@ def __init__(
         device: Union[str, torch.device] = torch.device("cpu"),
     ) -> None:
         super(PrecisionRecallCurve, self).__init__(
-            precision_recall_curve_compute_fn,
+            precision_recall_curve_compute_fn,  # type: ignore[arg-type]
             output_transform=output_transform,
             check_compute_fn=check_compute_fn,
             device=device,
         )

-    def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:  # type: ignore[override]
         if len(self._predictions) < 1 or len(self._targets) < 1:
             raise NotComputableError("PrecisionRecallCurve must have at least one example before it can be computed.")

-        _prediction_tensor = torch.cat(self._predictions, dim=0)
-        _target_tensor = torch.cat(self._targets, dim=0)
-
-        ws = idist.get_world_size()
-        if ws > 1 and not self._is_reduced:
-            # All gather across all processes
-            _prediction_tensor = cast(torch.Tensor, idist.all_gather(_prediction_tensor))
-            _target_tensor = cast(torch.Tensor, idist.all_gather(_target_tensor))
-            self._is_reduced = True
-
-        if idist.get_rank() == 0:
-            # Run compute_fn on zero rank only
-            precision, recall, thresholds = self.compute_fn(_prediction_tensor, _target_tensor)
-            precision = torch.tensor(precision)
-            recall = torch.tensor(recall)
-            # thresholds can have negative strides, not compatible with torch tensors
-            # https://discuss.pytorch.org/t/negative-strides-in-tensor-error/134287/2
-            thresholds = torch.tensor(thresholds.copy())
-        else:
-            precision, recall, thresholds = None, None, None
-
-        if ws > 1:
-            # broadcast result to all processes
-            precision = idist.broadcast(precision, src=0, safe_mode=True)
-            recall = idist.broadcast(recall, src=0, safe_mode=True)
-            thresholds = idist.broadcast(thresholds, src=0, safe_mode=True)
-
-        return precision, recall, thresholds
+        if self._result is None:
+            _prediction_tensor = torch.cat(self._predictions, dim=0)
+            _target_tensor = torch.cat(self._targets, dim=0)
+
+            ws = idist.get_world_size()
+            if ws > 1:
+                # All gather across all processes
+                _prediction_tensor = cast(torch.Tensor, idist.all_gather(_prediction_tensor))
+                _target_tensor = cast(torch.Tensor, idist.all_gather(_target_tensor))
+
+            if idist.get_rank() == 0:
+                # Run compute_fn on zero rank only
+                precision, recall, thresholds = cast(Tuple, self.compute_fn(_prediction_tensor, _target_tensor))
+                precision = torch.tensor(precision, device=_prediction_tensor.device)
+                recall = torch.tensor(recall, device=_prediction_tensor.device)
+                # thresholds can have negative strides, not compatible with torch tensors
+                # https://discuss.pytorch.org/t/negative-strides-in-tensor-error/134287/2
+                thresholds = torch.tensor(thresholds.copy(), device=_prediction_tensor.device)
+            else:
+                precision, recall, thresholds = None, None, None
+
+            if ws > 1:
+                # broadcast result to all processes
+                precision = idist.broadcast(precision, src=0, safe_mode=True)
+                recall = idist.broadcast(recall, src=0, safe_mode=True)
+                thresholds = idist.broadcast(thresholds, src=0, safe_mode=True)
+
+            self._result = (precision, recall, thresholds)  # type: ignore[assignment]
+
+        return cast(Tuple[torch.Tensor, torch.Tensor, torch.Tensor], self._result)
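The hunk above keeps the existing rank-0-computes / broadcast-to-all pattern and only wraps it in the self._result cache. For readers unfamiliar with that pattern, a minimal standalone sketch (assumes processes launched through ignite.distributed; sort() merely stands in for the expensive scikit-learn curve computation):

import torch

import ignite.distributed as idist


def compute_on_rank_zero_and_broadcast(values: torch.Tensor) -> torch.Tensor:
    if idist.get_rank() == 0:
        result = values.sort().values  # stand-in for the real computation
    else:
        result = None  # non-zero ranks contribute nothing
    # safe_mode=True lets non-source ranks pass None and still receive the broadcast tensor
    return idist.broadcast(result, src=0, safe_mode=True)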

ignite/contrib/metrics/roc_auc.py

Lines changed: 3 additions & 3 deletions
@@ -159,13 +159,13 @@ def __init__(
             raise ModuleNotFoundError("This contrib module requires scikit-learn to be installed.")

         super(RocCurve, self).__init__(
-            roc_auc_curve_compute_fn,
+            roc_auc_curve_compute_fn,  # type: ignore[arg-type]
             output_transform=output_transform,
             check_compute_fn=check_compute_fn,
             device=device,
         )

-    def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:  # type: ignore[override]
         if len(self._predictions) < 1 or len(self._targets) < 1:
             raise NotComputableError("RocCurve must have at least one example before it can be computed.")

@@ -180,7 +180,7 @@ def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

         if idist.get_rank() == 0:
             # Run compute_fn on zero rank only
-            fpr, tpr, thresholds = self.compute_fn(_prediction_tensor, _target_tensor)
+            fpr, tpr, thresholds = cast(Tuple, self.compute_fn(_prediction_tensor, _target_tensor))
             fpr = torch.tensor(fpr, device=_prediction_tensor.device)
             tpr = torch.tensor(tpr, device=_prediction_tensor.device)
             thresholds = torch.tensor(thresholds, device=_prediction_tensor.device)
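The only substantive change here is typing: EpochMetric now annotates compute_fn as returning a float, while the curve metrics pass in and unpack tuple-returning functions, hence the cast and the type: ignore comments. A minimal illustration of the same mypy situation (not ignite code):

from typing import Callable, Tuple, cast

import torch


class Base:
    def __init__(self, compute_fn: Callable[[torch.Tensor, torch.Tensor], float]) -> None:
        self.compute_fn = compute_fn

    def compute(self) -> float:
        return self.compute_fn(torch.zeros(3), torch.zeros(3))


def curve_fn(y_pred: torch.Tensor, y: torch.Tensor) -> Tuple:
    return y_pred, y, y_pred  # a tuple, not the float the base class expects


class Curve(Base):
    def __init__(self) -> None:
        super().__init__(curve_fn)  # type: ignore[arg-type]

    def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:  # type: ignore[override]
        a, b, c = cast(Tuple, self.compute_fn(torch.rand(3), torch.rand(3)))
        return a, b, c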

ignite/metrics/epoch_metric.py

Lines changed: 22 additions & 23 deletions
@@ -1,5 +1,5 @@
 import warnings
-from typing import Any, Callable, cast, List, Tuple, Union
+from typing import Callable, cast, List, Optional, Tuple, Union

 import torch

@@ -28,9 +28,8 @@ class EpochMetric(Metric):
     - ``update`` must receive output of the form ``(y_pred, y)``.

     Args:
-        compute_fn: a callable with the signature (`torch.tensor`, `torch.tensor`) takes as the input
-            `predictions` and `targets` and returns a scalar. Input tensors will be on specified ``device``
-            (see arg below).
+        compute_fn: a callable which receives two tensors as the `predictions` and `targets`
+            and returns a scalar. Input tensors will be on specified ``device`` (see arg below).
         output_transform: a callable that is used to transform the
             :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
             form expected by the metric. This can be useful if, for example, you have a multi-output model and
@@ -70,7 +69,7 @@ def mse_fn(y_preds, y_targets):

     def __init__(
         self,
-        compute_fn: Callable,
+        compute_fn: Callable[[torch.Tensor, torch.Tensor], float],
         output_transform: Callable = lambda x: x,
         check_compute_fn: bool = True,
         device: Union[str, torch.device] = torch.device("cpu"),
@@ -88,6 +87,7 @@ def __init__(
     def reset(self) -> None:
         self._predictions: List[torch.Tensor] = []
         self._targets: List[torch.Tensor] = []
+        self._result: Optional[float] = None

     def _check_shape(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
         y_pred, y = output
@@ -136,31 +136,30 @@ def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
         except Exception as e:
             warnings.warn(f"Probably, there can be a problem with `compute_fn`:\n {e}.", EpochMetricWarning)

-    def compute(self) -> Any:
+    def compute(self) -> float:
         if len(self._predictions) < 1 or len(self._targets) < 1:
             raise NotComputableError("EpochMetric must have at least one example before it can be computed.")

-        _prediction_tensor = torch.cat(self._predictions, dim=0)
-        _target_tensor = torch.cat(self._targets, dim=0)
+        if self._result is None:
+            _prediction_tensor = torch.cat(self._predictions, dim=0)
+            _target_tensor = torch.cat(self._targets, dim=0)

-        ws = idist.get_world_size()
+            ws = idist.get_world_size()
+            if ws > 1:
+                # All gather across all processes
+                _prediction_tensor = cast(torch.Tensor, idist.all_gather(_prediction_tensor))
+                _target_tensor = cast(torch.Tensor, idist.all_gather(_target_tensor))

-        if ws > 1 and not self._is_reduced:
-            # All gather across all processes
-            _prediction_tensor = cast(torch.Tensor, idist.all_gather(_prediction_tensor))
-            _target_tensor = cast(torch.Tensor, idist.all_gather(_target_tensor))
-            self._is_reduced = True
+            self._result = 0.0
+            if idist.get_rank() == 0:
+                # Run compute_fn on zero rank only
+                self._result = self.compute_fn(_prediction_tensor, _target_tensor)

-        result = 0.0
-        if idist.get_rank() == 0:
-            # Run compute_fn on zero rank only
-            result = self.compute_fn(_prediction_tensor, _target_tensor)
+            if ws > 1:
+                # broadcast result to all processes
+                self._result = cast(float, idist.broadcast(self._result, src=0))

-        if ws > 1:
-            # broadcast result to all processes
-            result = cast(float, idist.broadcast(result, src=0))
-
-        return result
+        return self._result


 class EpochMetricWarning(UserWarning):
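With the tightened annotation, a user-supplied compute_fn takes two tensors and returns a float. A sketch in the spirit of the mse_fn example from the docstring referenced in the hunk header above (an assumption for illustration, not copied from the source file):

import torch


def mse_fn(y_preds: torch.Tensor, y_targets: torch.Tensor) -> float:
    # matches Callable[[torch.Tensor, torch.Tensor], float]
    return torch.mean((y_preds - y_targets.type_as(y_preds)) ** 2).item()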

ignite/metrics/metric.py

Lines changed: 2 additions & 0 deletions
@@ -595,6 +595,8 @@ def reinit__is_reduced(func: Callable) -> Callable:
     def wrapper(self: Metric, *args: Any, **kwargs: Any) -> None:
         func(self, *args, **kwargs)
         self._is_reduced = False
+        if "_result" in self.__dict__:
+            self._result = None  # type: ignore[attr-defined]

     setattr(wrapper, "_decorated", True)
     return wrapper
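The wrapper change means any reset() decorated with reinit__is_reduced also drops a cached _result, if the metric defines one, so a reused metric cannot hand back stale values. A quick sanity sketch (it peeks at the private _result attribute purely for illustration):

import torch

from ignite.metrics import EpochMetric


def binary_accuracy(y_preds: torch.Tensor, y_targets: torch.Tensor) -> float:
    return (y_preds.round() == y_targets).float().mean().item()


metric = EpochMetric(binary_accuracy)
metric.update((torch.tensor([0.2, 0.7]), torch.tensor([0, 1])))

_ = metric.compute()             # first call fills the cache
assert metric._result is not None
metric.reset()                   # the wrapped reset() clears data and cache alike
assert metric._result is None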

tests/ignite/metrics/test_epoch_metric.py

Lines changed: 6 additions & 54 deletions
@@ -1,5 +1,3 @@
-import os
-
 import pytest
 import torch

@@ -153,13 +151,11 @@ def compute_fn(y_preds, y_targets):
     em.update(output1)


-def _test_distrib_integration(device=None):
-
-    if device is None:
-        device = idist.device() if idist.device().type != "xla" else "cpu"
+def test_distrib_integration(distributed):

+    device = idist.device() if idist.device().type != "xla" else "cpu"
     rank = idist.get_rank()
-    torch.manual_seed(12 + rank)
+    torch.manual_seed(40 + rank)

     n_iters = 3
     batch_size = 2
@@ -188,51 +184,7 @@ def assert_data_fn(all_preds, all_targets):

     y_preds = idist.all_gather(y_preds)
     y_true = idist.all_gather(y_true)
+    ep_metric_true = (y_preds.argmax(dim=1) == y_true).sum().item()

-    assert engine.state.metrics["epm"] == (y_preds.argmax(dim=1) == y_true).sum().item()
-
-
-@pytest.mark.distributed
-@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
-@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
-def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
-
-    device = idist.device()
-    _test_distrib_integration(device)
-
-
-@pytest.mark.distributed
-@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
-def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
-
-    device = idist.device()
-    _test_distrib_integration(device)
-
-
-@pytest.mark.tpu
-@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
-@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
-def test_distrib_single_device_xla():
-    _test_distrib_integration()
-
-
-def _test_distrib_xla_nprocs(index):
-    _test_distrib_integration()
-
-
-@pytest.mark.tpu
-@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
-@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
-def test_distrib_xla_nprocs(xmp_executor):
-    n = int(os.environ["NUM_TPU_WORKERS"])
-    xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n)
-
-
-@pytest.mark.distributed
-@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support")
-@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
-def test_distrib_hvd(gloo_hvd_executor):
-
-    nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
-
-    gloo_hvd_executor(_test_distrib_integration, (None,), np=nproc, do_init=True)
+    assert engine.state.metrics["epm"] == ep_metric_true
+    assert ep_metric.compute() == ep_metric_true
