@@ -22,6 +22,7 @@
 from unittest.mock import Mock

 import pytest
+import pytorch_lightning_enterprise.utils.imports
 import torch.distributed
 from tqdm import TMonitor

@@ -220,14 +221,9 @@ def mps_count_1(monkeypatch):


 def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None:
-    monkeypatch.setattr(lightning.pytorch.strategies.xla, "_XLA_AVAILABLE", value)
-    monkeypatch.setattr(lightning.pytorch.strategies.single_xla, "_XLA_AVAILABLE", value)
-    monkeypatch.setattr(lightning.pytorch.plugins.precision.xla, "_XLA_AVAILABLE", value)
-    monkeypatch.setattr(lightning.pytorch.strategies.launchers.xla, "_XLA_AVAILABLE", value)
-    monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_AVAILABLE", value)
-    monkeypatch.setattr(lightning.fabric.plugins.environments.xla, "_XLA_AVAILABLE", value)
-    monkeypatch.setattr(lightning.fabric.plugins.io.xla, "_XLA_AVAILABLE", value)
-    monkeypatch.setattr(lightning.fabric.strategies.launchers.xla, "_XLA_AVAILABLE", value)
+    monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_XLA_AVAILABLE", value)
+    monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_XLA_GREATER_EQUAL_2_1", value)
+    monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_XLA_GREATER_EQUAL_2_5", value)


 @pytest.fixture
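For context, a minimal usage sketch (not part of the commit) of how a test could rely on the centralized flags the helper now patches; the test name and loop are hypothetical, while the module path and flag names come from the diff above.

import pytest

import pytorch_lightning_enterprise.utils.imports as enterprise_imports


def test_xla_flags_can_be_toggled(monkeypatch: pytest.MonkeyPatch) -> None:
    # Mirror what the updated mock_xla_available() does: patch the flags on the
    # single imports module rather than on per-module _XLA_AVAILABLE copies.
    for flag in ("_XLA_AVAILABLE", "_XLA_GREATER_EQUAL_2_1", "_XLA_GREATER_EQUAL_2_5"):
        monkeypatch.setattr(enterprise_imports, flag, False)

    assert enterprise_imports._XLA_AVAILABLE is False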