Commit 859cbfd

fix loggers error

1 parent ff1b8b2 commit 859cbfd

4 files changed: +13 −4 lines changed


tests/tests_pytorch/conftest.py

Lines changed: 5 additions & 0 deletions

@@ -265,6 +265,11 @@ def xla_available(monkeypatch: pytest.MonkeyPatch) -> None:
     mock_xla_available(monkeypatch)
 
 
+@pytest.fixture
+def xla_not_available(monkeypatch: pytest.MonkeyPatch) -> None:
+    mock_xla_available(monkeypatch, False)
+
+
 def mock_tpu_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None:
     mock_xla_available(monkeypatch, value)
     monkeypatch.setattr(lightning.fabric.accelerators.xla.XLAAccelerator, "is_available", lambda: value)
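For context, a minimal sketch of how this fixture pair might hang together, assuming mock_xla_available flips a module-level availability flag via monkeypatch; the real helper lives elsewhere in conftest.py, and the patched attribute path here is an assumption, not taken from this diff:

import pytest

import lightning.fabric.accelerators.xla


def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None:
    # Hypothetical: patch whatever flag the XLA accelerator consults.
    monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_AVAILABLE", value)


@pytest.fixture
def xla_not_available(monkeypatch: pytest.MonkeyPatch) -> None:
    # Mirrors the fixture added above: same helper, availability forced off.
    mock_xla_available(monkeypatch, False)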

tests/tests_pytorch/graveyard/test_tpu.py

Lines changed: 1 addition & 1 deletion

@@ -35,7 +35,7 @@ def test_graveyard_single_tpu(import_path, name):
         ("lightning.pytorch.plugins.precision.xlabf16", "XLABf16PrecisionPlugin"),
     ],
 )
-def test_graveyard_no_device(import_path, name):
+def test_graveyard_no_device(import_path, name, xla_not_available):
     module = import_module(import_path)
     cls = getattr(module, name)
     with pytest.deprecated_call(match="is deprecated"), pytest.raises(ModuleNotFoundError, match="torch_xla"):
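Worth noting for readers less familiar with pytest: naming a fixture in the test signature, as this change does with xla_not_available, is enough to run it for its side effect even though its value is never used. A self-contained sketch of the pattern (the flag and all names are illustrative, not from this repo):

import sys

import pytest

_FEATURE_AVAILABLE = True  # hypothetical module-level flag


@pytest.fixture
def feature_not_available(monkeypatch: pytest.MonkeyPatch) -> None:
    # Runs purely for its side effect; pytest restores the flag after the test.
    monkeypatch.setattr(sys.modules[__name__], "_FEATURE_AVAILABLE", False)


def test_behaves_as_if_missing(feature_not_available):
    # Requesting the fixture by name activates the patch for this test only.
    assert _FEATURE_AVAILABLE is False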

tests/tests_pytorch/loggers/test_all.py

Lines changed: 5 additions & 3 deletions

@@ -107,8 +107,9 @@ def log_metrics(self, metrics, step):
 
     if logger_class == CometLogger:
         logger.experiment.id = "foo"
-        logger._comet_config.offline_directory = None
-        logger._project_name = "bar"
+        # TODO: Verify with @justusschock if this is an accepted approach to test experiment creation
+        logger.logger_impl._comet_config.offline_directory = None
+        logger.logger_impl._project_name = "bar"
         logger.experiment.get_key.return_value = "SOME_KEY"
 
     if logger_class == NeptuneLogger:
@@ -358,6 +359,7 @@ def test_logger_default_name(mlflow_mock, monkeypatch, tmp_path):
     logger = _instantiate_logger(MLFlowLogger, save_dir=tmp_path)
 
     _ = logger.experiment
-    logger._mlflow_client.create_experiment.assert_called_with(name="lightning_logs", artifact_location=ANY)
+    # TODO: Verify with @justusschock if this is an accepted approach to test experiment creation
+    logger.logger_impl._mlflow_client.create_experiment.assert_called_with(name="lightning_logs", artifact_location=ANY)
     # on MLFlowLogger `name` refers to the experiment id
     # assert logger.experiment.get_experiment(logger.name).name == "lightning_logs"
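The switch from logger._comet_config to logger.logger_impl._comet_config suggests the public logger now delegates to an inner implementation object, so the tests reach through the wrapper to touch private state. A minimal sketch of that shape, with every name except logger_impl assumed for illustration:

from types import SimpleNamespace


class _CometImpl:
    # Hypothetical inner object that owns the private state.
    def __init__(self) -> None:
        self._comet_config = SimpleNamespace(offline_directory="/tmp/comet")
        self._project_name = "default"


class CometLoggerFacade:
    # Hypothetical wrapper: the public API forwards to logger_impl.
    def __init__(self) -> None:
        self.logger_impl = _CometImpl()


logger = CometLoggerFacade()
# Tests must now patch through the wrapper, as in the diff above:
logger.logger_impl._comet_config.offline_directory = None
logger.logger_impl._project_name = "bar"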

tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py

Lines changed: 2 additions & 0 deletions

@@ -20,6 +20,7 @@
 from unittest.mock import Mock
 
 import pytest
+import pytorch_lightning_enterprise.plugins.precision.bitsandbytes
 import torch
 import torch.distributed
 
@@ -968,6 +969,7 @@ def test_precision_selection(precision_str, strategy_str, expected_precision_cls
 
 def test_bitsandbytes_precision_cuda_required(monkeypatch):
     monkeypatch.setattr(lightning.fabric.plugins.precision.bitsandbytes, "_BITSANDBYTES_AVAILABLE", True)
+    monkeypatch.setattr(pytorch_lightning_enterprise.plugins.precision.bitsandbytes, "_BITSANDBYTES_AVAILABLE", True)
     monkeypatch.setitem(sys.modules, "bitsandbytes", Mock())
     with pytest.raises(RuntimeError, match="Bitsandbytes is only supported on CUDA GPUs"):
         _AcceleratorConnector(accelerator="cpu", plugins=BitsandbytesPrecision(mode="int8"))
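The second monkeypatch.setattr is needed because each module holds its own binding of the availability flag: patching _BITSANDBYTES_AVAILABLE in lightning.fabric alone leaves the enterprise module's copy untouched. A self-contained sketch of why both must be patched (module names invented for the example):

import sys
import types


def _make_module(name: str) -> types.ModuleType:
    # Two hypothetical modules, each with its own copy of the flag.
    mod = types.ModuleType(name)
    mod._BITSANDBYTES_AVAILABLE = False
    sys.modules[name] = mod
    return mod


mod_a = _make_module("mod_a")
mod_b = _make_module("mod_b")


def test_both_flags_need_patching(monkeypatch):
    monkeypatch.setattr(mod_a, "_BITSANDBYTES_AVAILABLE", True)
    # Without the line below, code that consults mod_b still sees False.
    monkeypatch.setattr(mod_b, "_BITSANDBYTES_AVAILABLE", True)
    assert mod_a._BITSANDBYTES_AVAILABLE and mod_b._BITSANDBYTES_AVAILABLE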
