5151
class CustomCommunicationOp:
    """Holds per-rank custom communication state for allreduce/allgather.

    Keeps `reduce_num` independent slots so two different allreduce
    channels (indexed by ``all_reduce_id``) can each own their own
    vLLM custom-allreduce object and custom-gather object.
    """

    def __init__(self):
        # Number of independent allreduce channels supported.
        self.reduce_num = 2
        # Per-channel vLLM CustomAllreduce instances; populated lazily
        # (e.g. by set_custom_reduce), None until then.
        self.vllm_reduce = [None for _ in range(self.reduce_num)]
        # Per-channel custom gather objects; None until configured.
        self.custom_gather = [None for _ in range(self.reduce_num)]
        # NCCL process group used for the fallback/default path;
        # created elsewhere, None until then.
        self.device_group = None
5958
6059 @contextmanager
6160 def lightllm_capture_graph (self , all_reduce_id ):
62- if all_reduce_id == 0 :
63- if self .vllm_reduce1 is not None :
64- with self .vllm_reduce1 .capture ():
65- if self .custom_gather is not None :
66- with self .custom_gather .capture ():
67- yield
68- else :
61+ if self .vllm_reduce [all_reduce_id ] is not None :
62+ with self .vllm_reduce [all_reduce_id ].capture ():
63+ if self .custom_gather [all_reduce_id ] is not None :
64+ with self .custom_gather [all_reduce_id ].capture ():
6965 yield
70- else :
71- yield
66+ else :
67+ yield
7268 else :
73- if self .vllm_reduce2 is not None :
74- with self .vllm_reduce2 .capture ():
75- if self .custom_gather2 is not None :
76- with self .custom_gather2 .capture ():
77- yield
78- else :
79- yield
80- else :
81- yield
69+ yield
8270
8371 def set_custom_reduce (self ):
8472 ENABLE_VLLM_REDUCE = os .getenv ("ENABLE_VLLM_REDUCE" , "True" ).upper () in ["ON" , "TRUE" , "1" ]
@@ -97,17 +85,16 @@ def set_custom_reduce(self):
9785 self .device_group = dist .new_group (ranks , backend = "nccl" )
9886
9987 if ENABLE_VLLM_REDUCE and HAS_VLLM :
100- cpu_group1 = dist .new_group (ranks , backend = "gloo" )
101- self .vllm_reduce1 = CustomAllreduce (cpu_group1 , torch .cuda .current_device ())
102- cpu_group2 = dist .new_group (ranks , backend = "gloo" )
103- self .vllm_reduce2 = CustomAllreduce (cpu_group2 , torch .cuda .current_device ())
88+ cpu_group = [dist .new_group (ranks , backend = "gloo" )] * self .reduce_num
89+ for i in range (self .reduce_num ):
90+ self .vllm_reduce [i ] = CustomAllreduce (cpu_group [i ], torch .cuda .current_device ())
10491 logger .info ("Enable VLLM ALLReduce." )
10592
10693 def _all_reduce_closure (input_ , op = ReduceOp .SUM , group = self .device_group , async_op = False , all_reduce_id = 0 ):
10794 if op != ReduceOp .SUM or async_op :
10895 original_all_reduce (input_ , op , group , async_op )
10996 else :
110- vllm_reduce = self .vllm_reduce1 if all_reduce_id == 0 else self . vllm_reduce2
97+ vllm_reduce = self .vllm_reduce [ all_reduce_id ]
11198 if vllm_reduce is not None and vllm_reduce .should_custom_ar (input_ ):
11299 input_ .data = vllm_reduce .custom_all_reduce (input_ )
113100 else :
0 commit comments