|
25 | 25 | from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars |
26 | 26 | from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin |
27 | 27 | from lightning.fabric.utilities.distributed import _distributed_available |
28 | | -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0 |
| 28 | +from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_0 |
29 | 29 | from lightning.pytorch.utilities.data import extract_batch_size |
30 | 30 | from lightning.pytorch.utilities.exceptions import MisconfigurationException |
31 | 31 | from lightning.pytorch.utilities.memory import recursive_detach |
@@ -112,7 +112,7 @@ class _Metadata: |
112 | 112 | on_step: bool = False |
113 | 113 | on_epoch: bool = True |
114 | 114 | # https://github.com/pytorch/pytorch/issues/96197 |
115 | | - reduce_fx: Callable = "mean" if _TORCH_GREATER_EQUAL_2_0 else torch.mean # type: ignore[assignment] |
| 115 | + reduce_fx: Callable = "mean" if _TORCH_EQUAL_2_0 else torch.mean # type: ignore[assignment] |
116 | 116 | enable_graph: bool = False |
117 | 117 | add_dataloader_idx: bool = True |
118 | 118 | dataloader_idx: Optional[int] = None |
@@ -352,7 +352,7 @@ def log( |
352 | 352 | on_step: bool = False, |
353 | 353 | on_epoch: bool = True, |
354 | 354 | # https://github.com/pytorch/pytorch/issues/96197 |
355 | | - reduce_fx: Callable = "mean" if _TORCH_GREATER_EQUAL_2_0 else torch.mean, # type: ignore[assignment] |
| 355 | + reduce_fx: Callable = "mean" if _TORCH_EQUAL_2_0 else torch.mean, # type: ignore[assignment] |
356 | 356 | enable_graph: bool = False, |
357 | 357 | sync_dist: bool = False, |
358 | 358 | sync_dist_fn: Callable = _Sync.no_op, |
|
0 commit comments