```diff
@@ -19,6 +19,7 @@
 from unittest.mock import Mock
 
 import pytest
+import pytorch_lightning_enterprise
 import torch.distributed
 
 import lightning.fabric
@@ -144,13 +145,8 @@ def reset_cudnn_benchmark():
 
 
 def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None:
+    monkeypatch.setattr(pytorch_lightning_enterprise.utilities.imports, "_XLA_AVAILABLE", value)
     monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_AVAILABLE", value)
-    monkeypatch.setattr(lightning.fabric.plugins.environments.xla, "_XLA_AVAILABLE", value)
-    monkeypatch.setattr(lightning.fabric.plugins.precision.xla, "_XLA_AVAILABLE", value)
-    monkeypatch.setattr(lightning.fabric.plugins.io.xla, "_XLA_AVAILABLE", value)
-    monkeypatch.setattr(lightning.fabric.strategies.single_xla, "_XLA_AVAILABLE", value)
-    monkeypatch.setattr(lightning.fabric.strategies.xla_fsdp, "_XLA_AVAILABLE", value)
-    monkeypatch.setattr(lightning.fabric.strategies.launchers.xla, "_XLA_AVAILABLE", value)
     monkeypatch.setitem(sys.modules, "torch_xla", Mock())
     monkeypatch.setitem(sys.modules, "torch_xla.core.xla_model", Mock())
     monkeypatch.setitem(sys.modules, "torch_xla.experimental", Mock())
```
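The old helper needed one `setattr` per `lightning.fabric` submodule because of Python's import-time name binding: each module that binds its own `_XLA_AVAILABLE` holds an independent reference, so patching one copy does not affect the others. Replacing six patches with a single patch of `pytorch_lightning_enterprise.utilities.imports` only works if consumers look the flag up through that shared module at call time. A minimal sketch of the binding behavior (the module names here are illustrative, not from this commit):

```python
import sys
import types

# Two toy consumer modules, each binding its own copy of the flag at
# "import" time -- mirroring how every lightning.fabric submodule used
# to hold a private _XLA_AVAILABLE reference.
consumer_a = types.ModuleType("consumer_a")
consumer_a._XLA_AVAILABLE = False
consumer_b = types.ModuleType("consumer_b")
consumer_b._XLA_AVAILABLE = False
sys.modules["consumer_a"] = consumer_a
sys.modules["consumer_b"] = consumer_b

# Patching one copy leaves the other unchanged, which is why the old
# helper had to patch every consumer module separately.
consumer_a._XLA_AVAILABLE = True
assert consumer_a._XLA_AVAILABLE
assert not consumer_b._XLA_AVAILABLE
```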
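For context, a hedged sketch of how a test might consume the updated helper via pytest's `monkeypatch` fixture; the test name and the import path of the conftest helper are assumptions for illustration, not taken from this commit:

```python
import sys

import pytest

from conftest import mock_xla_available  # hypothetical path to this helper


def test_xla_flag_and_modules_are_mocked(monkeypatch: pytest.MonkeyPatch) -> None:
    # Force the availability flag on and stub out the torch_xla modules.
    mock_xla_available(monkeypatch, value=True)

    import lightning.fabric.accelerators.xla as xla_accelerator

    # The patched module-level flag now reports XLA as available, and a
    # "torch_xla" import resolves to a Mock instead of failing.
    assert xla_accelerator._XLA_AVAILABLE
    assert "torch_xla" in sys.modules
```

Because `monkeypatch` undoes both `setattr` and `setitem` patches at teardown, each test sees a clean `sys.modules` and unpatched flags afterwards.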