1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14- import os
1514from typing import Optional , Union
1615
17- import torch
1816from typing_extensions import override
1917
2018import lightning .pytorch as pl
21- from lightning .fabric .accelerators .xla import _XLA_AVAILABLE
2219from lightning .fabric .plugins import CheckpointIO , Precision , XLACheckpointIO
2320from lightning .fabric .strategies import _StrategyRegistry
24- from lightning .fabric .utilities .optimizer import _optimizers_to_device
21+ from lightning .fabric .utilities .imports import _raise_enterprise_not_available
2522from lightning .fabric .utilities .types import _DEVICE
2623from lightning .pytorch .plugins .io .wrapper import _WrappingCheckpointIO
2724from lightning .pytorch .plugins .precision .xla import XLAPrecision
2825from lightning .pytorch .strategies .single_device import SingleDeviceStrategy
29- from lightning .pytorch .trainer .states import TrainerFn
30- from lightning .pytorch .utilities import find_shared_parameters , set_shared_parameters
3126
3227
3328class SingleDeviceXLAStrategy (SingleDeviceStrategy ):
@@ -41,20 +36,18 @@ def __init__(
4136 precision_plugin : Optional [XLAPrecision ] = None ,
4237 debug : bool = False ,
4338 ):
44- if not _XLA_AVAILABLE :
45- raise ModuleNotFoundError (str (_XLA_AVAILABLE ))
46- if isinstance (device , torch .device ):
47- # unwrap the `torch.device` in favor of `xla_device`
48- device = device .index
49- import torch_xla .core .xla_model as xm
50-
5139 super ().__init__ (
5240 accelerator = accelerator ,
53- device = xm . xla_device ( device ) ,
41+ device = device ,
5442 checkpoint_io = checkpoint_io ,
5543 precision_plugin = precision_plugin ,
5644 )
57- self .debug = debug
45+ _raise_enterprise_not_available ()
46+ from pytorch_lightning_enterprise .strategies .xla .single import (
47+ SingleDeviceXLAStrategyTrainer as EnterpriseSingleDeviceXLAStrategy ,
48+ )
49+
50+ self .single_xla_strategy_impl = EnterpriseSingleDeviceXLAStrategy (outer_object = self , device = device , debug = debug )
5851
5952 @property
6053 @override
@@ -90,26 +83,7 @@ def precision_plugin(self, precision_plugin: Optional[Precision]) -> None:
9083
9184 @override
9285 def setup (self , trainer : "pl.Trainer" ) -> None :
93- if self .debug :
94- os .environ ["PT_XLA_DEBUG" ] = str (1 )
95-
96- assert self .accelerator is not None
97- self .accelerator .setup (trainer )
98-
99- assert self .model is not None
100- self .precision_plugin .convert_module (self .model )
101-
102- shared_params = find_shared_parameters (self .model )
103- self .model_to_device ()
104- set_shared_parameters (self .model , shared_params )
105-
106- self .model = self ._setup_model (self .model )
107-
108- if trainer .state .fn == TrainerFn .FITTING :
109- self .setup_optimizers (trainer )
110- self .setup_precision_plugin ()
111- if trainer .state .fn == TrainerFn .FITTING :
112- _optimizers_to_device (self .optimizers , self .root_device )
86+ return self .single_xla_strategy_impl .setup (trainer = trainer )
11387
11488 @classmethod
11589 @override
@@ -118,5 +92,4 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
11892
11993 @override
12094 def teardown (self ) -> None :
121- super ().teardown ()
122- os .environ .pop ("PT_XLA_DEBUG" , None )
95+ return self .single_xla_strategy_impl .teardown ()
0 commit comments