Skip to content

Commit b188440

Browse files
authored
fix grpo ddp hang (#3476)
* update * fix * rename --------- Co-authored-by: hjh <[email protected]>
1 parent 6e982d7 commit b188440

File tree

1 file changed

+21
-1
lines changed

1 file changed

+21
-1
lines changed

swift/llm/infer/infer_engine/utils.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ def map_rank_to_real_device(obj):
433433
GroupCoordinator.__init__ = __init__
434434

435435
try:
436-
with profiling_patch, set_local_rank_context(vllm_device):
436+
with profiling_patch, restore_torch_device_after_vllm_init(), set_local_rank_context(vllm_device):
437437
torch.distributed.get_world_size_origin = torch.distributed.get_world_size
438438
torch.distributed.get_world_size = get_world_size
439439
yield
@@ -486,3 +486,23 @@ def set_local_rank_context(device: Union[str, int]):
486486
os.environ['LOCAL_RANK'] = origin_local_rank
487487
else:
488488
del os.environ['LOCAL_RANK']
489+
490+
491+
@contextmanager
def restore_torch_device_after_vllm_init():
    """Restore the current CUDA device once the wrapped block has finished.

    In Distributed Data Parallel (DDP) runs, initializing the vLLM engine may
    change the process's current CUDA device as a side effect. This context
    manager records the current device on entry and, on exit, switches back
    to it if it differs. The restore also happens when the wrapped block
    raises; the exception then propagates unchanged.

    When ``torch.cuda.is_available()`` is false there is no CUDA device to
    save or restore, so the context manager is a no-op instead of failing
    inside ``torch.cuda.current_device()``.

    Yields:
        None.
    """
    if not torch.cuda.is_available():
        # Nothing to track on hosts without CUDA; just run the wrapped block.
        yield
        return

    origin_device = torch.cuda.current_device()
    try:
        yield
    finally:
        # Only touch the device when it actually changed, so the common case
        # does not issue a redundant set_device call.
        if torch.cuda.current_device() != origin_device:
            torch.cuda.set_device(origin_device)

0 commit comments

Comments
 (0)