
Commit 0348b70

using spawn instead of fork for XPU device (#3884)
* using `spawn` instead of `fork` for XPU device
* update code
* remove comment

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
1 parent fa6e13d commit 0348b70
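
For context: `fork` copies the parent process, so a device runtime (XPU, CUDA) that is already initialized in the parent ends up in a broken state in the child, while `spawn` starts a fresh interpreter. A minimal standalone sketch of the two start methods, not part of this commit (the worker is illustrative):

import multiprocessing as mp


def worker(i: int) -> None:
    # Each child just reports its index; in real use it would be free to
    # initialize the XPU/CUDA runtime, since nothing was inherited.
    print(f"worker {i} started")


if __name__ == "__main__":
    # "spawn" starts a fresh interpreter per child; "fork" copies the
    # parent, which fails if a device runtime was already initialized there.
    ctx = mp.get_context("spawn")
    procs = [ctx.Process(target=worker, args=(i,)) for i in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()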

File tree

2 files changed: +15 −6 lines

src/accelerate/launchers.py

Lines changed: 4 additions & 2 deletions
@@ -206,10 +206,12 @@ def train(*args):
             # First dummy launch
             # Determine device type without initializing any device (which would break fork)
             device_type, distributed_type = get_current_device_type()
+            # XPU requires spawn instead of fork
+            start_method = "spawn" if device_type == "xpu" else "fork"
             if os.environ.get("ACCELERATE_DEBUG_MODE", "false").lower() == "true":
                 launcher = PrepareForLaunch(test_launch, distributed_type=distributed_type)
                 try:
-                    start_processes(launcher, args=(), nprocs=num_processes, start_method="fork")
+                    start_processes(launcher, args=(), nprocs=num_processes, start_method=start_method)
                 except ProcessRaisedException as e:
                     err = "An issue was found when verifying a stable environment for the notebook launcher."
                     if f"Cannot re-initialize {device_type.upper()} in forked subprocess" in e.args[0]:
@@ -241,7 +243,7 @@ def train(*args):
                 rdzv_configs=rdzv_conf,
                 max_restarts=max_restarts,
                 monitor_interval=monitor_interval,
-                start_method="fork",
+                start_method=start_method,
             )
             if is_torch_version(">=", ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION):
                 launch_config_kwargs["log_line_prefix_template"] = log_line_prefix_template
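
The selected `start_method` now feeds both launch paths: the dummy verification launch through `start_processes` and the elastic agent's launch config. A minimal sketch of just the selection logic (the `pick_start_method` helper is hypothetical; the commit inlines the same conditional):

def pick_start_method(device_type: str) -> str:
    # XPU cannot be re-initialized in a forked subprocess, so it needs
    # "spawn"; "fork" stays the default for everything else.
    return "spawn" if device_type == "xpu" else "fork"


assert pick_start_method("xpu") == "spawn"
assert pick_start_method("cuda") == "fork"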

src/accelerate/test_utils/scripts/test_notebook.py

Lines changed: 11 additions & 4 deletions
@@ -17,22 +17,21 @@
 
 import os
 import time
-from multiprocessing import Queue
 
 from pytest import mark, raises
 from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
 
 from accelerate import PartialState, notebook_launcher
 from accelerate.test_utils import require_bnb
-from accelerate.utils import is_bnb_available
+from accelerate.utils import is_bnb_available, is_xpu_available
 
 
 def basic_function():
     # Just prints the PartialState
     print(f"PartialState:\n{PartialState()}")
 
 
-def tough_nut_function(queue: Queue):
+def tough_nut_function(queue):
     if queue.empty():
         return
     trial = queue.get()
@@ -70,7 +69,15 @@ def test_c10d_rdzv_backend():
 
 @mark.skipif(NUM_PROCESSES < 2, reason="Need at least 2 processes to test fault tolerance")
 def test_fault_tolerant(max_restarts: int = 3):
-    queue = Queue()
+    # Use torch.multiprocessing to get the right context for the current device
+    import torch.multiprocessing as mp
+
+    # Get appropriate context - 'spawn' for XPU, 'fork' for others
+    if is_xpu_available():
+        ctx = mp.get_context("spawn")
+    else:
+        ctx = mp.get_context("fork")
+    queue = ctx.Queue()
     queue.put(max_restarts)
     notebook_launcher(tough_nut_function, (queue,), num_processes=NUM_PROCESSES, max_restarts=max_restarts)
 
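
Note that the test's queue now comes from a multiprocessing context matching the launcher's start method; a queue created under the default fork context may not transfer cleanly to spawned children. A standalone sketch of the same pattern, without the torch or accelerate dependencies (the `drain` worker is illustrative):

import multiprocessing as mp


def drain(queue) -> None:
    # Child pops the value the parent enqueued.
    if not queue.empty():
        print(f"got {queue.get()}")


if __name__ == "__main__":
    # Create the queue from the same context whose start method the
    # children use; mixing contexts can fail at handle transfer time.
    ctx = mp.get_context("spawn")
    queue = ctx.Queue()
    queue.put(3)
    p = ctx.Process(target=drain, args=(queue,))
    p.start()
    p.join()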
