@@ -153,6 +153,8 @@ def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None:
     monkeypatch.setitem(sys.modules, "torch_xla.core.xla_model", Mock())
     monkeypatch.setitem(sys.modules, "torch_xla.experimental", Mock())
     monkeypatch.setitem(sys.modules, "torch_xla.distributed.fsdp.wrap", Mock())
+    monkeypatch.setitem(sys.modules, "torch_xla._internal", Mock())
+    monkeypatch.setitem(sys.modules, "torch_xla._internal.tpu", Mock())
 
     # Then patch the _XLA_AVAILABLE flags in various modules
     monkeypatch.setattr(pytorch_lightning_enterprise.utils.imports, "_XLA_AVAILABLE", value)
@@ -161,6 +163,12 @@ def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None:
     monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_AVAILABLE", value)
     monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_GREATER_EQUAL_2_1", value)
     monkeypatch.setattr(lightning.fabric.accelerators.xla, "_XLA_GREATER_EQUAL_2_5", value)
+    # Also patch the flags in the modules where they're used, since they're bound there at import time
+    monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_AVAILABLE", value)
+    monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_GREATER_EQUAL_2_1", value)
+    monkeypatch.setattr("pytorch_lightning_enterprise.accelerators.xla._XLA_GREATER_EQUAL_2_5", value)
+    monkeypatch.setattr("pytorch_lightning_enterprise.plugins.environments.xla._XLA_AVAILABLE", value)
+    monkeypatch.setattr("pytorch_lightning_enterprise.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", value)
 
 
 @pytest.fixture
@@ -172,6 +180,11 @@ def mock_tpu_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> None:
     mock_xla_available(monkeypatch, value)
     monkeypatch.setattr(lightning.fabric.accelerators.xla.XLAAccelerator, "is_available", lambda: value)
     monkeypatch.setattr(lightning.fabric.accelerators.xla.XLAAccelerator, "auto_device_count", lambda *_: 8)
+    # Also mock the enterprise XLAAccelerator methods
+    import pytorch_lightning_enterprise.accelerators.xla
+
+    monkeypatch.setattr(pytorch_lightning_enterprise.accelerators.xla.XLAAccelerator, "is_available", lambda: value)
+    monkeypatch.setattr(pytorch_lightning_enterprise.accelerators.xla.XLAAccelerator, "auto_device_count", lambda *_: 8)
 
 
 @pytest.fixture
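
The string-path `monkeypatch.setattr` calls are the important detail here: a `from ... import _XLA_AVAILABLE` copies the flag into the importing module's namespace at import time, so patching only the defining module would leave stale copies in `accelerators.xla` and `plugins.environments.xla`. Patching by string path updates the name in each consuming module directly.

A minimal sketch of how these helpers might be consumed from a test (the fixture and test names below are illustrative assumptions, not part of this diff; it assumes the code lives alongside the helpers in `conftest.py`):

```python
import pytest


@pytest.fixture
def tpu_available(monkeypatch: pytest.MonkeyPatch) -> None:
    # Hypothetical fixture: delegates to the helper above so every
    # cached copy of the XLA flags is patched consistently.
    mock_tpu_available(monkeypatch, True)


def test_tpu_devices_are_visible(tpu_available) -> None:
    from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator

    # Both methods were replaced on the class by mock_tpu_available:
    # is_available -> lambda: True, auto_device_count -> lambda *_: 8
    assert XLAAccelerator.is_available()
    assert XLAAccelerator.auto_device_count() == 8
```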