
Commit e68c226

forward trainer xla single device
1 parent ed8f618 commit e68c226

File tree

1 file changed: +10 −37 lines changed


src/lightning/pytorch/strategies/single_xla.py

Lines changed: 10 additions & 37 deletions
@@ -11,23 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from typing import Optional, Union
 
-import torch
 from typing_extensions import override
 
 import lightning.pytorch as pl
-from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
 from lightning.fabric.plugins import CheckpointIO, Precision, XLACheckpointIO
 from lightning.fabric.strategies import _StrategyRegistry
-from lightning.fabric.utilities.optimizer import _optimizers_to_device
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
 from lightning.fabric.utilities.types import _DEVICE
 from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO
 from lightning.pytorch.plugins.precision.xla import XLAPrecision
 from lightning.pytorch.strategies.single_device import SingleDeviceStrategy
-from lightning.pytorch.trainer.states import TrainerFn
-from lightning.pytorch.utilities import find_shared_parameters, set_shared_parameters
 
 
 class SingleDeviceXLAStrategy(SingleDeviceStrategy):
@@ -41,20 +36,18 @@ def __init__(
         precision_plugin: Optional[XLAPrecision] = None,
         debug: bool = False,
     ):
-        if not _XLA_AVAILABLE:
-            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
-        if isinstance(device, torch.device):
-            # unwrap the `torch.device` in favor of `xla_device`
-            device = device.index
-        import torch_xla.core.xla_model as xm
-
         super().__init__(
             accelerator=accelerator,
-            device=xm.xla_device(device),
+            device=device,
             checkpoint_io=checkpoint_io,
             precision_plugin=precision_plugin,
         )
-        self.debug = debug
+        _raise_enterprise_not_available()
+        from pytorch_lightning_enterprise.strategies.xla.single import (
+            SingleDeviceXLAStrategyTrainer as EnterpriseSingleDeviceXLAStrategy,
+        )
+
+        self.single_xla_strategy_impl = EnterpriseSingleDeviceXLAStrategy(outer_object=self, device=device, debug=debug)
 
     @property
     @override
@@ -90,26 +83,7 @@ def precision_plugin(self, precision_plugin: Optional[Precision]) -> None:
 
     @override
     def setup(self, trainer: "pl.Trainer") -> None:
-        if self.debug:
-            os.environ["PT_XLA_DEBUG"] = str(1)
-
-        assert self.accelerator is not None
-        self.accelerator.setup(trainer)
-
-        assert self.model is not None
-        self.precision_plugin.convert_module(self.model)
-
-        shared_params = find_shared_parameters(self.model)
-        self.model_to_device()
-        set_shared_parameters(self.model, shared_params)
-
-        self.model = self._setup_model(self.model)
-
-        if trainer.state.fn == TrainerFn.FITTING:
-            self.setup_optimizers(trainer)
-        self.setup_precision_plugin()
-        if trainer.state.fn == TrainerFn.FITTING:
-            _optimizers_to_device(self.optimizers, self.root_device)
+        return self.single_xla_strategy_impl.setup(trainer=trainer)
 
     @classmethod
     @override
@@ -118,5 +92,4 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
 
     @override
     def teardown(self) -> None:
-        super().teardown()
-        os.environ.pop("PT_XLA_DEBUG", None)
+        return self.single_xla_strategy_impl.teardown()
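
The net effect of the diff above is that SingleDeviceXLAStrategy keeps its public constructor, setup, and teardown signatures, but now builds an enterprise-side object and forwards those calls to it. Below is a minimal, self-contained sketch of that forwarding pattern; the class names PublicStrategy and EnterpriseImpl are illustrative stand-ins, not the actual Lightning or pytorch_lightning_enterprise classes.

class EnterpriseImpl:
    """Stand-in for the delegate object that owns the real setup/teardown logic."""

    def __init__(self, outer_object, device, debug=False):
        self.outer = outer_object  # back-reference to the public strategy
        self.device = device
        self.debug = debug

    def setup(self, trainer):
        print(f"setting up on {self.device} (debug={self.debug})")

    def teardown(self):
        print("tearing down")


class PublicStrategy:
    """Stand-in for the public strategy: same API as before, but it only delegates."""

    def __init__(self, device, debug=False):
        # the public class stores the delegate and forwards every call to it
        self.impl = EnterpriseImpl(outer_object=self, device=device, debug=debug)

    def setup(self, trainer):
        return self.impl.setup(trainer)

    def teardown(self):
        return self.impl.teardown()


if __name__ == "__main__":
    strategy = PublicStrategy(device="xla:0", debug=True)
    strategy.setup(trainer=None)
    strategy.teardown()

One consequence of the split is that the constructor now calls _raise_enterprise_not_available() before importing the delegate, which, judging by its name, raises an informative error when the pytorch_lightning_enterprise package is not installed.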

0 commit comments
