swift/llm/infer/infer_engine (1 file changed: +21 −1 lines)

```diff
@@ -433,7 +433,7 @@ def map_rank_to_real_device(obj):
     GroupCoordinator.__init__ = __init__
 
     try:
-        with profiling_patch, set_local_rank_context(vllm_device):
+        with profiling_patch, restore_torch_device_after_vllm_init(), set_local_rank_context(vllm_device):
             torch.distributed.get_world_size_origin = torch.distributed.get_world_size
             torch.distributed.get_world_size = get_world_size
             yield
```
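Python exits context managers in the reverse order they are entered, so placing `restore_torch_device_after_vllm_init()` before `set_local_rank_context(vllm_device)` means the CUDA device is restored only after the `LOCAL_RANK` cleanup has already run. A minimal, self-contained sketch of this exit ordering (the `outer`/`inner` names are hypothetical, for illustration only):

```python
from contextlib import contextmanager

@contextmanager
def outer():  # plays the role of restore_torch_device_after_vllm_init()
    print('enter outer')
    try:
        yield
    finally:
        print('exit outer (runs last)')

@contextmanager
def inner():  # plays the role of set_local_rank_context()
    print('enter inner')
    try:
        yield
    finally:
        print('exit inner (runs first)')

with outer(), inner():
    print('body')
# Prints: enter outer, enter inner, body,
#         exit inner (runs first), exit outer (runs last)
```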
```diff
@@ -486,3 +486,23 @@ def set_local_rank_context(device: Union[str, int]):
             os.environ['LOCAL_RANK'] = origin_local_rank
         else:
             del os.environ['LOCAL_RANK']
+
+
+@contextmanager
+def restore_torch_device_after_vllm_init():
+    """
+    A context manager that restores the original CUDA device after potential modifications.
+
+    This specifically addresses an issue in Distributed Data Parallel (DDP) scenarios
+    where initializing the vLLM engine may inadvertently change the default CUDA
+    device. The manager records the current device on entry and, if it was changed
+    inside the context, restores it on exit.
+    """
+    origin_device = torch.cuda.current_device()
+    try:
+        yield
+    finally:
+        current_device = torch.cuda.current_device()
+        if origin_device != current_device:
+            torch.cuda.set_device(origin_device)
```
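A hedged usage sketch of the new context manager in isolation. The vLLM engine construction is stood in for by a plain `torch.cuda.set_device()` call that mimics the device change engine initialization can trigger under DDP; running it requires a machine with at least two CUDA devices, and the function body mirrors the patch above:

```python
from contextlib import contextmanager

import torch

@contextmanager
def restore_torch_device_after_vllm_init():
    # Same logic as in the patch: record the device on entry, restore on exit.
    origin_device = torch.cuda.current_device()
    try:
        yield
    finally:
        if torch.cuda.current_device() != origin_device:
            torch.cuda.set_device(origin_device)

torch.cuda.set_device(0)
with restore_torch_device_after_vllm_init():
    torch.cuda.set_device(1)  # stands in for the vLLM engine init side effect
assert torch.cuda.current_device() == 0  # device was restored on exit
```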