Skip to content

Commit 8d5ac23

Browse files
committed
update
1 parent 304f82b commit 8d5ac23

File tree

3 files changed

+6
-2
lines changed

3 files changed

+6
-2
lines changed

src/lightning/fabric/utilities/imports.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
_DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
5151
_DEEPSPEED_GREATER_EQUAL_0_16 = RequirementCache("deepspeed>=0.16.0")
5252
_ENTERPRISE_AVAILABLE = RequirementCache("pytorch_lightning_enterprise")
53+
_TRANSFORMER_ENGINE_AVAILABLE = RequirementCache("transformer_engine>=0.11.0")
5354

5455

5556
def _raise_enterprise_not_available() -> None:

tests/tests_pytorch/deprecated_api/test_no_removal_version.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from unittest.mock import Mock
33

44
import pytest
5+
import pytorch_lightning_enterprise.utils.imports
56
import torch.nn
67

78
import lightning.fabric
@@ -62,6 +63,7 @@ def test_fsdp_precision_plugin():
6263

6364
def test_bitsandbytes_precision_plugin(monkeypatch):
6465
monkeypatch.setattr(lightning.fabric.plugins.precision.bitsandbytes, "_BITSANDBYTES_AVAILABLE", True)
66+
monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_BITSANDBYTES_AVAILABLE", True)
6567
bitsandbytes_mock = Mock()
6668
monkeypatch.setitem(sys.modules, "bitsandbytes", bitsandbytes_mock)
6769

@@ -107,7 +109,8 @@ def test_precision_plugin():
107109

108110

109111
def test_transformer_engine_precision_plugin(monkeypatch):
110-
monkeypatch.setattr(lightning.fabric.plugins.precision.transformer_engine, "_TRANSFORMER_ENGINE_AVAILABLE", True)
112+
monkeypatch.setattr(lightning.fabric.utilities.imports, "_TRANSFORMER_ENGINE_AVAILABLE", True)
113+
monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_TRANSFORMER_ENGINE_AVAILABLE", True)
111114
transformer_engine_mock = Mock()
112115
monkeypatch.setitem(sys.modules, "transformer_engine", transformer_engine_mock)
113116
monkeypatch.setitem(sys.modules, "transformer_engine.pytorch", Mock())

tests/tests_pytorch/plugins/precision/test_transformer_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626

2727
def test_transformer_engine_precision_plugin(monkeypatch):
28-
module = lightning.fabric.plugins.precision.transformer_engine
28+
module = lightning.fabric.utilities.imports
2929
if module._TRANSFORMER_ENGINE_AVAILABLE:
3030
pytest.skip("Assumes transformer_engine is unavailable")
3131
monkeypatch.setattr(module, "_TRANSFORMER_ENGINE_AVAILABLE", lambda: True)

0 commit comments

Comments
 (0)