# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
import queue
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import torch.multiprocessing as mp
from typing_extensions import override

-from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
-from lightning.fabric.strategies.launchers.xla import _rank_teardown
-from lightning.fabric.utilities import move_data_to_device
from lightning.pytorch.strategies.launchers.multiprocessing import (
    _GlobalStateSnapshot,
    _MultiProcessingLauncher,
    _WorkerOutput,
)
-from lightning.pytorch.trainer.states import TrainerFn
-from lightning.pytorch.utilities.rank_zero import rank_zero_debug
+from lightning.pytorch.utilities.imports import _raise_if_not_enterprise_not_available

if TYPE_CHECKING:
    import lightning.pytorch as pl
@@ -51,14 +46,16 @@ class _XLALauncher(_MultiProcessingLauncher):
5146 """
5247
5348 def __init__ (self , strategy : "pl.strategies.XLAStrategy" ) -> None :
54- if not _XLA_AVAILABLE :
55- raise ModuleNotFoundError (str (_XLA_AVAILABLE ))
56- super ().__init__ (strategy = strategy , start_method = "fork" )
49+ super ().__init__ (strategy )
50+ _raise_if_not_enterprise_not_available ()
51+ from pytorch_lightning_enterprise .strategies .xla .launcher import _XLALauncherTrainer as EnterpriseXLALauncher
52+
53+ self .xla_launcher_impl = EnterpriseXLALauncher (strategy )
5754
5855 @property
5956 @override
6057 def is_interactive_compatible (self ) -> bool :
61- return True
58+ return self . xla_launcher_impl . is_interactive_compatible ()
6259
6360 @override
6461 def launch (self , function : Callable , * args : Any , trainer : Optional ["pl.Trainer" ] = None , ** kwargs : Any ) -> Any :
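# --- Illustrative sketch (not part of the diff) -------------------------------
# The constructor hunk above replaces the old `_XLA_AVAILABLE` check with an
# enterprise-availability guard. The helper below is a hypothetical example of
# what such a guard could look like; the real `_raise_if_not_enterprise_not_available`
# lives in `lightning.pytorch.utilities.imports` and may be implemented differently.
import importlib.util


def _example_enterprise_guard() -> None:
    # Hypothetical illustration only: fail fast when the optional dependency is missing.
    if importlib.util.find_spec("pytorch_lightning_enterprise") is None:
        raise ModuleNotFoundError(
            "XLA launching now requires the `pytorch_lightning_enterprise` package to be installed."
        )
# ------------------------------------------------------------------------------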
@@ -75,46 +72,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
            **kwargs: Optional keyword arguments to be passed to the given function.

        """
-        if self._already_fit and trainer is not None and trainer.state.fn == TrainerFn.FITTING:
-            # resolving https://github.com/Lightning-AI/pytorch-lightning/issues/18775 will lift this restriction
-            raise NotImplementedError(
-                "Calling `trainer.fit()` twice on the same Trainer instance using a spawn-based strategy is not"
-                " supported. You can work around this by creating a new Trainer instance and passing the"
-                " `fit(ckpt_path=...)` argument."
-            )
-
-        # pjrt requires that the queue is serializable
-        return_queue = mp.Manager().Queue()
-
-        import torch_xla.distributed.xla_multiprocessing as xmp
-
-        spawn_kwargs = {}
-        nprocs = self._strategy.num_processes
-        if nprocs == 1:
-            # avoid warning: "Unsupported nprocs". If it's 1, it will call the launched function directly.
-            # otherwise it will use all devices
-            spawn_kwargs["nprocs"] = nprocs
-
-        process_context = xmp.spawn(
-            self._wrapping_function,
-            args=(trainer, function, args, kwargs, return_queue),
-            start_method=self._start_method,
-            join=False,  # we will join ourselves to get the process references
-            **spawn_kwargs,
-        )
-        # xla will not actually create processes if only 1 device
-        if process_context is not None:
-            self.procs = process_context.processes
-            while not process_context.join():
-                pass
-
-        worker_output = return_queue.get()
-        if trainer is None:
-            return worker_output
-
-        self._already_fit |= trainer.state.fn == TrainerFn.FITTING
-        self._recover_results_in_main_process(worker_output, trainer)
-        return worker_output.trainer_results
+        return self.xla_launcher_impl.launch(function, *args, trainer=trainer, **kwargs)

    @override
    def _wrapping_function(
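# --- Illustrative sketch (not part of the diff) -------------------------------
# The removed `launch` body above relied on `torch_xla.distributed.xla_multiprocessing.spawn`,
# which starts one process per XLA device and calls the target with the process
# index as its first argument. A minimal standalone example of that mechanism
# (illustrative only; the enterprise launcher now encapsulates this logic):
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _worker(index: int) -> None:
    # Each spawned process sees its own XLA device.
    print(f"process {index} running on {xm.xla_device()}")


if __name__ == "__main__":
    # With the default nprocs, torch_xla uses all available devices;
    # "fork" matches the start method the launcher used.
    xmp.spawn(_worker, args=(), start_method="fork")
# ------------------------------------------------------------------------------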
@@ -129,48 +87,10 @@ def _wrapping_function(
        return_queue: Union[mp.SimpleQueue, queue.Queue],
        global_states: Optional[_GlobalStateSnapshot] = None,
    ) -> None:
-        import torch_xla.core.xla_model as xm
-
-        if len(xm.get_xla_supported_devices()) > 1:
-            # `get_xla_supported_devices` in the spawned process returns the logical devices (2 for v2/v3 and 1 for v4)
-            # so when there's more than one (multithreading), objects need to be deep-copied
-            import copy
-
-            trainer, function, args, kwargs = copy.deepcopy((trainer, function, args, kwargs))
-
-        results = function(*args, **kwargs)
-
-        if trainer is not None:
-            results = self._collect_rank_zero_results(trainer, results)
-
-        if self._strategy.local_rank == 0:
-            return_queue.put(move_data_to_device(results, "cpu"))
-
-        _rank_teardown(self._strategy.local_rank)
+        return self.xla_launcher_impl._wrapping_function(
+            process_idx, trainer, function, args, kwargs, return_queue, global_states
+        )

    @override
    def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_WorkerOutput"]:
-        rank_zero_debug("Collecting results from rank 0 process.")
-        checkpoint_callback = trainer.checkpoint_callback
-        best_model_path = (
-            checkpoint_callback.best_model_path
-            if checkpoint_callback and hasattr(checkpoint_callback, "best_model_path")
-            else None
-        )
-
-        # save the last weights
-        weights_path = None
-        if trainer.state.fn == TrainerFn.FITTING:
-            # requires to compute the state_dict on all processes in case Metrics are present
-            state_dict = self._strategy.lightning_module_state_dict()
-            weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt")
-            self._strategy.checkpoint_io.save_checkpoint(state_dict, weights_path)
-
-        # We use `local_rank` here as separate filesystems are used for each VM for TPU Pod Training
-        if self._strategy.local_rank != 0:
-            return None
-
-        # add extra result data from trainer to send to main process
-        extra = self.get_extra_results(trainer)
-
-        return _WorkerOutput(best_model_path, weights_path, trainer.state, results, extra)
+        return self.xla_launcher_impl._collect_rank_zero_results(trainer, results)
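# --- Illustrative sketch (not part of the diff) -------------------------------
# From the user's perspective the public API is unchanged: running on TPUs still
# goes through `_XLALauncher`, which now forwards its work to the
# `pytorch_lightning_enterprise` implementation (and raises if that package is
# not installed). A minimal, assumed usage example:
import lightning.pytorch as pl


def train_on_tpu(model: pl.LightningModule, datamodule: pl.LightningDataModule) -> None:
    # The XLAStrategy / _XLALauncher pair is selected automatically for TPU accelerators.
    trainer = pl.Trainer(accelerator="tpu", devices=8, max_epochs=1)
    trainer.fit(model, datamodule=datamodule)
# ------------------------------------------------------------------------------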