Skip to content

Commit ff1b8b2

Browse files
committed
update
1 parent f212eab commit ff1b8b2

File tree

3 files changed

+10
-3
lines changed

3 files changed

+10
-3
lines changed

tests/tests_pytorch/deprecated_api/test_no_removal_version.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from unittest.mock import Mock
33

44
import pytest
5-
import pytorch_lightning_enterprise.utils.imports
5+
import pytorch_lightning_enterprise.plugins.precision.bitsandbytes
66
import torch.nn
77

88
import lightning.fabric
@@ -63,7 +63,7 @@ def test_fsdp_precision_plugin():
6363

6464
def test_bitsandbytes_precision_plugin(monkeypatch):
6565
monkeypatch.setattr(lightning.fabric.plugins.precision.bitsandbytes, "_BITSANDBYTES_AVAILABLE", True)
66-
monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_BITSANDBYTES_AVAILABLE", True)
66+
monkeypatch.setattr(pytorch_lightning_enterprise.plugins.precision.bitsandbytes, "_BITSANDBYTES_AVAILABLE", True)
6767
bitsandbytes_mock = Mock()
6868
monkeypatch.setitem(sys.modules, "bitsandbytes", bitsandbytes_mock)
6969

tests/tests_pytorch/graveyard/test_tpu.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
from importlib import import_module
23

34
import pytest
@@ -39,3 +40,9 @@ def test_graveyard_no_device(import_path, name):
3940
cls = getattr(module, name)
4041
with pytest.deprecated_call(match="is deprecated"), pytest.raises(ModuleNotFoundError, match="torch_xla"):
4142
cls()
43+
44+
# teardown
45+
# ideally, we should call the plugin's teardown method, but since the class
46+
# instantiation itself fails, we directly manipulate the env vars here
47+
os.environ.pop("XLA_USE_BF16", None)
48+
os.environ.pop("XLA_USE_F16", None)

tests/tests_pytorch/loggers/test_mlflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def test_mlflow_logger_experiment_calls_with_synchronous(mlflow_mock, tmp_path,
304304

305305

306306
@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._get_resolve_tags", Mock())
307-
@mock.patch("pytorch_lightning_enterprise.utils.imports._MLFLOW_SYNCHRONOUS_AVAILABLE", False)
307+
@mock.patch("pytorch_lightning_enterprise.loggers.mlflow._MLFLOW_SYNCHRONOUS_AVAILABLE", False)
308308
def test_mlflow_logger_no_synchronous_support(mlflow_mock, tmp_path):
309309
"""Test that the logger does not support synchronous flag."""
310310
time = mlflow_mock.entities.time

0 commit comments

Comments
 (0)