Skip to content

Commit 8666d8d

Browse files
committed
fix transformer engine mock
1 parent ab4d2a0 commit 8666d8d

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

tests/tests_pytorch/deprecated_api/test_no_removal_version.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,9 @@ def test_precision_plugin():
109109

110110

111111
def test_transformer_engine_precision_plugin(monkeypatch):
112-
monkeypatch.setattr(lightning.fabric.utilities.imports, "_TRANSFORMER_ENGINE_AVAILABLE", True)
113-
monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_TRANSFORMER_ENGINE_AVAILABLE", True)
112+
monkeypatch.setattr(
113+
pytorch_lightning_enterprise.plugins.precision.transformer_engine, "_TRANSFORMER_ENGINE_AVAILABLE", True
114+
)
114115
transformer_engine_mock = Mock()
115116
monkeypatch.setitem(sys.modules, "transformer_engine", transformer_engine_mock)
116117
monkeypatch.setitem(sys.modules, "transformer_engine.pytorch", Mock())

0 commit comments

Comments
 (0)