
Commit 8434ee7

Update dynamo bug workaround condition (#17065)
1 parent 5fe9e93 commit 8434ee7

File tree: 2 files changed (+4, -3 lines)

  src/lightning/fabric/utilities/imports.py
  src/lightning/pytorch/trainer/connectors/logger_connector/result.py

src/lightning/fabric/utilities/imports.py (1 addition, 0 deletions)

@@ -29,3 +29,4 @@
 _TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0")
 _TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0", use_base_version=True)
 _TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0", use_base_version=True)
+_TORCH_EQUAL_2_0 = _TORCH_GREATER_EQUAL_2_0 and not _TORCH_GREATER_EQUAL_2_1
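
Note: the new _TORCH_EQUAL_2_0 flag is true only for the torch 2.0.x series: it requires the 2.0 floor while excluding 2.1, where the upstream dynamo issue is expected to be fixed. Below is a minimal standalone sketch of the same gating logic using packaging directly; it mirrors the flag names in the diff but is illustrative, not the Lightning compare_version helper itself.

    import torch
    from packaging.version import Version

    # Base version drops pre-release and local suffixes such as "+cu118" or
    # ".dev20230401", similar in spirit to compare_version(..., use_base_version=True).
    _torch_base = Version(Version(torch.__version__).base_version)

    _TORCH_GREATER_EQUAL_2_0 = _torch_base >= Version("2.0.0")
    _TORCH_GREATER_EQUAL_2_1 = _torch_base >= Version("2.1.0")

    # True only for 2.0.x, where the dynamo bug (pytorch/pytorch#96197)
    # is present but the later fix is not.
    _TORCH_EQUAL_2_0 = _TORCH_GREATER_EQUAL_2_0 and not _TORCH_GREATER_EQUAL_2_1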

src/lightning/pytorch/trainer/connectors/logger_connector/result.py (3 additions, 3 deletions)

@@ -25,7 +25,7 @@
 from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars
 from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
 from lightning.fabric.utilities.distributed import _distributed_available
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
+from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_0
 from lightning.pytorch.utilities.data import extract_batch_size
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.memory import recursive_detach
@@ -112,7 +112,7 @@ class _Metadata:
     on_step: bool = False
     on_epoch: bool = True
     # https://github.com/pytorch/pytorch/issues/96197
-    reduce_fx: Callable = "mean" if _TORCH_GREATER_EQUAL_2_0 else torch.mean  # type: ignore[assignment]
+    reduce_fx: Callable = "mean" if _TORCH_EQUAL_2_0 else torch.mean  # type: ignore[assignment]
     enable_graph: bool = False
     add_dataloader_idx: bool = True
     dataloader_idx: Optional[int] = None
@@ -352,7 +352,7 @@ def log(
        on_step: bool = False,
        on_epoch: bool = True,
        # https://github.com/pytorch/pytorch/issues/96197
-       reduce_fx: Callable = "mean" if _TORCH_GREATER_EQUAL_2_0 else torch.mean,  # type: ignore[assignment]
+       reduce_fx: Callable = "mean" if _TORCH_EQUAL_2_0 else torch.mean,  # type: ignore[assignment]
        enable_graph: bool = False,
        sync_dist: bool = False,
        sync_dist_fn: Callable = _Sync.no_op,
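
Note: with this change the string placeholder "mean" is used as the default only on torch 2.0, where passing torch.mean as a dataclass default trips the dynamo bug referenced above; on all other versions the real callable is used. Downstream, a string default still has to behave like torch.mean, so it must be mapped back to a callable before reduction. The sketch below shows one way to do that; the _resolve_reduce_fx helper and its lookup table are hypothetical and not part of the Lightning codebase.

    from typing import Callable, Union

    import torch

    # Hypothetical mapping from the string placeholder back to the reduction
    # callable it stands in for.
    _REDUCE_FX_BY_NAME = {"mean": torch.mean, "sum": torch.sum, "max": torch.max, "min": torch.min}

    def _resolve_reduce_fx(reduce_fx: Union[str, Callable]) -> Callable:
        """Return the actual reduction callable, resolving string placeholders."""
        if isinstance(reduce_fx, str):
            return _REDUCE_FX_BY_NAME[reduce_fx]
        return reduce_fx

    values = torch.tensor([1.0, 2.0, 3.0])
    print(_resolve_reduce_fx("mean")(values))      # tensor(2.)
    print(_resolve_reduce_fx(torch.mean)(values))  # tensor(2.)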
