Skip to content

Commit 1d683e6

Browse files
committed
update
1 parent 1eb83ac commit 1d683e6

File tree

3 files changed

+20
-1
lines changed

3 files changed

+20
-1
lines changed

tests/tests_fabric/graveyard/test_tpu.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ def test_graveyard_single_tpu(import_path, name, tpu_available):
1818
with pytest.deprecated_call(match="is deprecated"):
1919
cls(device)
2020

21+
# required to prevent env-var leakage
22+
if hasattr(cls, "teardown"):
23+
cls.teardown()
24+
2125

2226
@pytest.mark.parametrize(
2327
("import_path", "name"),
@@ -39,3 +43,7 @@ def test_graveyard_no_device(import_path, name, tpu_available):
3943
cls = getattr(module, name)
4044
with pytest.deprecated_call(match="is deprecated"):
4145
cls()
46+
47+
# required to prevent env-var leakage
48+
if hasattr(cls, "teardown"):
49+
cls.teardown()

tests/tests_pytorch/graveyard/test_tpu.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ def test_graveyard_single_tpu(import_path, name):
1818
with pytest.deprecated_call(match="is deprecated"), pytest.raises(ModuleNotFoundError, match="torch_xla"):
1919
cls(device)
2020

21+
# required to prevent env-var leakage
22+
if hasattr(cls, "teardown"):
23+
cls.teardown()
24+
2125

2226
@pytest.mark.parametrize(
2327
("import_path", "name"),
@@ -39,3 +43,7 @@ def test_graveyard_no_device(import_path, name):
3943
cls = getattr(module, name)
4044
with pytest.deprecated_call(match="is deprecated"), pytest.raises(ModuleNotFoundError, match="torch_xla"):
4145
cls()
46+
47+
# required to prevent env-var leakage
48+
if hasattr(cls, "teardown"):
49+
cls.teardown()

tests/tests_pytorch/utilities/migration/test_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,10 @@ def test_patch_legacy_imports_unified(pl_version):
125125
assert any(key.startswith("lightning." + "pytorch") for key in sys.modules), (
126126
f"Imported unified package, so it has to be in sys.modules: {_list_sys_modules('lightning' + '.pytorch')}"
127127
)
128-
assert not any(key.startswith("pytorch_lightning") for key in sys.modules), (
128+
assert not any(
129+
key.startswith("pytorch_lightning") and not key.startswith("pytorch_lightning_enterprise")
130+
for key in sys.modules
131+
), (
129132
"Should not import standalone package, all imports should be redirected to the unified package;\n"
130133
f" environment: {_list_sys_modules('pytorch_lightning')}"
131134
)

0 commit comments

Comments
 (0)