Commit 975c098

Commit message: update
1 parent 3aa1981 · commit 975c098

Summary: this commit adds a `_XLA_AVAILABLE` requirement flag to fabric's XLA accelerator and moves the PyTorch XLA launcher internals behind the pytorch_lightning_enterprise package.

File tree
  • src/lightning

2 files changed: 15 additions, 92 deletions

src/lightning/fabric/accelerators/xla.py
3 additions, 0 deletions

@@ -16,12 +16,15 @@
 from typing import Any, Union
 
 import torch
+from lightning_utilities.core.imports import RequirementCache
 from typing_extensions import override
 
 from lightning.fabric.accelerators.accelerator import Accelerator
 from lightning.fabric.accelerators.registry import _AcceleratorRegistry
 from lightning.fabric.utilities.imports import _raise_enterprise_not_available
 
+_XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
+
 
 class XLAAccelerator(Accelerator):
     """Accelerator for XLA devices, normally TPUs.

src/lightning/pytorch/strategies/launchers/xla.py
12 additions, 92 deletions

@@ -11,23 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 import queue
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import torch.multiprocessing as mp
 from typing_extensions import override
 
-from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
-from lightning.fabric.strategies.launchers.xla import _rank_teardown
-from lightning.fabric.utilities import move_data_to_device
 from lightning.pytorch.strategies.launchers.multiprocessing import (
     _GlobalStateSnapshot,
     _MultiProcessingLauncher,
     _WorkerOutput,
 )
-from lightning.pytorch.trainer.states import TrainerFn
-from lightning.pytorch.utilities.rank_zero import rank_zero_debug
+from lightning.pytorch.utilities.imports import _raise_if_not_enterprise_not_available
 
 if TYPE_CHECKING:
     import lightning.pytorch as pl
@@ -51,14 +46,16 @@ class _XLALauncher(_MultiProcessingLauncher):
     """
 
     def __init__(self, strategy: "pl.strategies.XLAStrategy") -> None:
-        if not _XLA_AVAILABLE:
-            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
-        super().__init__(strategy=strategy, start_method="fork")
+        super().__init__(strategy)
+        _raise_if_not_enterprise_not_available()
+        from pytorch_lightning_enterprise.strategies.xla.launcher import _XLALauncherTrainer as EnterpriseXLALauncher
+
+        self.xla_launcher_impl = EnterpriseXLALauncher(strategy)
 
     @property
     @override
     def is_interactive_compatible(self) -> bool:
-        return True
+        return self.xla_launcher_impl.is_interactive_compatible()
 
     @override
     def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
@@ -75,46 +72,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
         **kwargs: Optional keyword arguments to be passed to the given function.
 
         """
-        if self._already_fit and trainer is not None and trainer.state.fn == TrainerFn.FITTING:
-            # resolving https://github.com/Lightning-AI/pytorch-lightning/issues/18775 will lift this restriction
-            raise NotImplementedError(
-                "Calling `trainer.fit()` twice on the same Trainer instance using a spawn-based strategy is not"
-                " supported. You can work around this by creating a new Trainer instance and passing the"
-                " `fit(ckpt_path=...)` argument."
-            )
-
-        # pjrt requires that the queue is serializable
-        return_queue = mp.Manager().Queue()
-
-        import torch_xla.distributed.xla_multiprocessing as xmp
-
-        spawn_kwargs = {}
-        nprocs = self._strategy.num_processes
-        if nprocs == 1:
-            # avoid warning: "Unsupported nprocs". If it's 1, it will call the launched function directly.
-            # otherwise it will use all devices
-            spawn_kwargs["nprocs"] = nprocs
-
-        process_context = xmp.spawn(
-            self._wrapping_function,
-            args=(trainer, function, args, kwargs, return_queue),
-            start_method=self._start_method,
-            join=False,  # we will join ourselves to get the process references
-            **spawn_kwargs,
-        )
-        # xla will not actually create processes if only 1 device
-        if process_context is not None:
-            self.procs = process_context.processes
-            while not process_context.join():
-                pass
-
-        worker_output = return_queue.get()
-        if trainer is None:
-            return worker_output
-
-        self._already_fit |= trainer.state.fn == TrainerFn.FITTING
-        self._recover_results_in_main_process(worker_output, trainer)
-        return worker_output.trainer_results
+        return self.xla_launcher_impl.launch(function, *args, trainer=trainer, **kwargs)
 
     @override
     def _wrapping_function(
@@ -129,48 +87,10 @@ def _wrapping_function(
         return_queue: Union[mp.SimpleQueue, queue.Queue],
         global_states: Optional[_GlobalStateSnapshot] = None,
     ) -> None:
-        import torch_xla.core.xla_model as xm
-
-        if len(xm.get_xla_supported_devices()) > 1:
-            # `get_xla_supported_devices` in the spawned process returns the logical devices (2 for v2/v3 and 1 for v4)
-            # so when there's more than one (multithreading), objects need to be deep-copied
-            import copy
-
-            trainer, function, args, kwargs = copy.deepcopy((trainer, function, args, kwargs))
-
-        results = function(*args, **kwargs)
-
-        if trainer is not None:
-            results = self._collect_rank_zero_results(trainer, results)
-
-        if self._strategy.local_rank == 0:
-            return_queue.put(move_data_to_device(results, "cpu"))
-
-        _rank_teardown(self._strategy.local_rank)
+        return self.xla_launcher_impl._wrapping_function(
+            process_idx, trainer, function, args, kwargs, return_queue, global_states
+        )
 
     @override
     def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_WorkerOutput"]:
-        rank_zero_debug("Collecting results from rank 0 process.")
-        checkpoint_callback = trainer.checkpoint_callback
-        best_model_path = (
-            checkpoint_callback.best_model_path
-            if checkpoint_callback and hasattr(checkpoint_callback, "best_model_path")
-            else None
-        )
-
-        # save the last weights
-        weights_path = None
-        if trainer.state.fn == TrainerFn.FITTING:
-            # requires to compute the state_dict on all processes in case Metrics are present
-            state_dict = self._strategy.lightning_module_state_dict()
-            weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt")
-            self._strategy.checkpoint_io.save_checkpoint(state_dict, weights_path)
-
-        # We use `local_rank` here as separate filesystems are used for each VM for TPU Pod Training
-        if self._strategy.local_rank != 0:
-            return None
-
-        # add extra result data from trainer to send to main process
-        extra = self.get_extra_results(trainer)
-
-        return _WorkerOutput(best_model_path, weights_path, trainer.state, results, extra)
+        return self.xla_launcher_impl._collect_rank_zero_results(trainer, results)
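The delegation above implies a small contract on the enterprise launcher. A hypothetical sketch of that interface, inferred only from the calls made in this diff (the actual `_XLALauncherTrainer` definition lives in pytorch_lightning_enterprise and is not part of this commit):

    from typing import Any, Callable, Optional, Protocol

    class _XLALauncherContract(Protocol):
        # A plain method, not a property: the diff calls it with parentheses.
        def is_interactive_compatible(self) -> bool: ...

        # Inferred from `launch` forwarding its function, args, trainer, and kwargs.
        def launch(self, function: Callable, *args: Any, trainer: Optional[Any] = None, **kwargs: Any) -> Any: ...

        # Inferred from `_wrapping_function` forwarding all seven arguments positionally.
        def _wrapping_function(self, process_idx: int, trainer: Any, function: Callable,
                               args: Any, kwargs: Any, return_queue: Any, global_states: Any) -> None: ...

        # Inferred from `_collect_rank_zero_results` forwarding trainer and results.
        def _collect_rank_zero_results(self, trainer: Any, results: Any) -> Optional[Any]: ...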
