Skip to content

Commit bc33500

Browse files
committed
update
1 parent 555f531 commit bc33500

File tree

4 files changed

+10
-4
lines changed

4 files changed

+10
-4
lines changed

src/lightning/fabric/utilities/imports.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@
4141
_TORCHMETRICS_GREATER_EQUAL_1_0_0 = compare_version("torchmetrics", operator.ge, "1.0.0")
4242
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
4343

44+
_WANDB_AVAILABLE = RequirementCache("wandb>=0.12.10")
45+
_COMET_AVAILABLE = RequirementCache("comet-ml>=3.44.4")
46+
_MLFLOW_AVAILABLE = RequirementCache("mlflow>=1.0.0")
47+
_MLFLOW_SYNCHRONOUS_AVAILABLE = RequirementCache("mlflow>=2.8.0")
48+
_NEPTUNE_AVAILABLE = RequirementCache("neptune>=1.0")
49+
4450
_DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
4551
_DEEPSPEED_GREATER_EQUAL_0_16 = RequirementCache("deepspeed>=0.16.0")
4652
_ENTERPRISE_AVAILABLE = RequirementCache("pytorch_lightning_enterprise")

src/lightning/pytorch/plugins/precision/deepspeed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class DeepSpeedPrecision(Precision):
4343
"""
4444

4545
def __init__(self, precision: _PRECISION_INPUT) -> None:
46-
super().__init__(precision)
46+
super().__init__()
4747
_raise_enterprise_not_available()
4848
from pytorch_lightning_enterprise.plugins.precision.deepspeed import (
4949
DeepSpeedPrecisionTrainer as EnterpriseDeepSpeedPrecision,

src/lightning/pytorch/utilities/deepspeed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020

2121
import torch
2222

23+
from lightning.fabric.utilities.imports import _DEEPSPEED_AVAILABLE
2324
from lightning.fabric.utilities.types import _PATH
24-
from lightning.pytorch.strategies.deepspeed import _DEEPSPEED_AVAILABLE
2525

2626
CPU_DEVICE = torch.device("cpu")
2727

tests/tests_pytorch/loggers/test_mlflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,11 @@
1717

1818
import pytest
1919

20+
from lightning.fabric.utilities.imports import _MLFLOW_AVAILABLE
2021
from lightning.pytorch import Trainer
2122
from lightning.pytorch.demos.boring_classes import BoringModel
2223
from lightning.pytorch.loggers.mlflow import (
23-
_MLFLOW_AVAILABLE,
2424
MLFlowLogger,
25-
_get_resolve_tags,
2625
)
2726

2827

@@ -96,6 +95,7 @@ def test_mlflow_run_name_setting(tmp_path):
9695
pytest.skip("test for explicit file creation requires mlflow dependency to be installed.")
9796

9897
from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME
98+
from pytorch_lightning_enterprise.loggers.mlflow import _get_resolve_tags
9999

100100
resolve_tags = _get_resolve_tags()
101101
tags = resolve_tags({MLFLOW_RUN_NAME: "run-name-1"})

0 commit comments

Comments
 (0)