Skip to content

Commit fdae213

Browse files
use new update_called from metrics (#18714)
Co-authored-by: awaelchli <[email protected]>
1 parent 74d4020 commit fdae213

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

src/lightning/pytorch/trainer/connectors/logger_connector/result.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_0
2727
from lightning.pytorch.utilities.data import extract_batch_size
2828
from lightning.pytorch.utilities.exceptions import MisconfigurationException
29+
from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_1_0_0
2930
from lightning.pytorch.utilities.memory import recursive_detach
3031
from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn
3132
from lightning.pytorch.utilities.warnings import PossibleUserWarning
@@ -265,7 +266,8 @@ def _wrap_compute(self, compute: Any) -> Any:
265266
# Override to avoid syncing - we handle it ourselves.
266267
@wraps(compute)
267268
def wrapped_func(*args: Any, **kwargs: Any) -> Optional[Any]:
268-
if not self._update_called:
269+
update_called = self.update_called if _TORCHMETRICS_GREATER_EQUAL_1_0_0 else self._update_called
270+
if not update_called:
269271
rank_zero_warn(
270272
f"The ``compute`` method of metric {self.__class__.__name__}"
271273
" was called before the ``update`` method which may lead to errors,"

src/lightning/pytorch/utilities/imports.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
_PYTHON_GREATER_EQUAL_3_11_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 11)
2222
_TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1")
2323
_TORCHMETRICS_GREATER_EQUAL_0_11 = RequirementCache("torchmetrics>=0.11.0") # using new API with task
24+
_TORCHMETRICS_GREATER_EQUAL_1_0_0 = RequirementCache("torchmetrics>=1.0.0")
2425

2526
_OMEGACONF_AVAILABLE = package_available("omegaconf")
2627
_TORCHVISION_AVAILABLE = RequirementCache("torchvision")

0 commit comments

Comments
 (0)