 import torch.distributed as dist
 from torch.distributed import ReduceOp
 from lightllm.utils.log_utils import init_logger
+from functools import partial
 
 original_all_reduce = torch.distributed.all_reduce
 from contextlib import nullcontext, contextmanager
@@ -47,31 +48,30 @@ def lightllm_capture_graph():
     pass
 
 def _all_reduce(input_, op=ReduceOp.SUM, group=None, async_op=False):
-    if op != ReduceOp.SUM or group is not None or async_op:
+    if op != ReduceOp.SUM or async_op:
         original_all_reduce(input_, op, group, async_op)
     else:
         if vllm_reduce is not None:
             can_use = vllm_reduce.should_custom_ar(input_)
             if can_use:
                 input_.data = vllm_reduce.custom_all_reduce(input_)
                 return
-            original_all_reduce(input_, op, vllm_reduce.device_group, async_op)
-        else:
-            original_all_reduce(input_, op, group, async_op)
-
+        original_all_reduce(input_, op, group, async_op)
 
 def set_custom_reduce():
     global vllm_reduce
+    global device_group
     ENABLE_VLLM_REDUCE = os.getenv("ENABLE_VLLM_REDUCE", "False").upper() in [
         "ON",
         "TRUE",
         "1",
     ]
+    world_size = dist.get_world_size()
+    ranks = list(range(world_size))
+    # Use a dedicated new_group so the original torch all_reduce does not get stuck under cudagraph capture
+    device_group = torch.distributed.new_group(ranks, backend="nccl")
     if ENABLE_VLLM_REDUCE and HAS_VLLM:
-        world_size = dist.get_world_size()
-        ranks = list(range(world_size))
-        device_group = torch.distributed.new_group(ranks, backend="nccl")
         cpu_group = torch.distributed.new_group(ranks, backend="gloo")
-        vllm_reduce = CustomAllreduce(cpu_group, device_group, torch.cuda.current_device())
+        vllm_reduce = CustomAllreduce(cpu_group, torch.cuda.current_device())
         logger.info("Enable VLLM ALLReduce.")
-    dist.all_reduce = _all_reduce
+    dist.all_reduce = partial(_all_reduce, group=device_group)
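
For context, a minimal per-rank usage sketch of the patched path. The import path for set_custom_reduce and the launch setup are assumptions (the diff shows only the function bodies); the torch.distributed calls are standard, and MASTER_ADDR/MASTER_PORT are assumed to be provided by the launcher.

import os
import torch
import torch.distributed as dist

# Hypothetical import path: the diff shows set_custom_reduce() but not which
# lightllm module it lives in.
from lightllm.distributed import set_custom_reduce

def run(rank: int, world_size: int):
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    # Opt in to the vLLM custom all-reduce; without this flag only the
    # dedicated NCCL group created by new_group() is bound to _all_reduce.
    os.environ["ENABLE_VLLM_REDUCE"] = "1"
    set_custom_reduce()  # monkey-patches dist.all_reduce via functools.partial

    x = torch.ones(16, device="cuda")
    # SUM + synchronous: routed through _all_reduce with group=device_group;
    # falls back to the original torch all_reduce when the custom path declines.
    dist.all_reduce(x)
    assert torch.allclose(x, torch.full_like(x, float(world_size)))

Because the group is bound as a keyword through partial, callers that pass their own group= keyword still override it, and dropping the "group is not None" check means such calls are no longer forced onto the original all_reduce path.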