@@ -49,18 +49,6 @@ def lightllm_capture_graph():
     pass
 
 
-def _all_reduce(input_, op=ReduceOp.SUM, group=None, async_op=False):
-    if op != ReduceOp.SUM or async_op:
-        original_all_reduce(input_, op, group, async_op)
-    else:
-        if vllm_reduce is not None:
-            can_use = vllm_reduce.should_custom_ar(input_)
-            if can_use:
-                input_.data = vllm_reduce.custom_all_reduce(input_)
-                return
-        original_all_reduce(input_, op, group, async_op)
-
-
 def set_custom_reduce():
     global vllm_reduce
     global device_group
@@ -77,4 +65,14 @@ def set_custom_reduce():
         cpu_group = torch.distributed.new_group(ranks, backend="gloo")
         vllm_reduce = CustomAllreduce(cpu_group, torch.cuda.current_device())
         logger.info("Enable VLLM ALLReduce.")
-        dist.all_reduce = partial(_all_reduce, group=device_group)
+
+    def _all_reduce_closure(input_, op=ReduceOp.SUM, group=device_group, async_op=False):
+        if op != ReduceOp.SUM or async_op:
+            original_all_reduce(input_, op, group, async_op)
+        else:
+            if vllm_reduce is not None and vllm_reduce.should_custom_ar(input_):
+                input_.data = vllm_reduce.custom_all_reduce(input_)
+            else:
+                original_all_reduce(input_, op, group, async_op)
+
+    dist.all_reduce = _all_reduce_closure
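
For reference, here is a minimal sketch of the pattern this diff applies: the module-level `_all_reduce` wrapper installed via `functools.partial` is replaced by `_all_reduce_closure`, defined inside `set_custom_reduce()` so that `device_group` is bound as a default argument at the moment the patch is installed, and the fallback to the original `dist.all_reduce` is an explicit `else` branch. The helpers `fast_allreduce_available` and `fast_allreduce` below are placeholders for `vllm_reduce.should_custom_ar` and `vllm_reduce.custom_all_reduce`, used only so the sketch can be read and run without the vLLM dependency; they are not part of lightllm or vLLM.

```python
# Minimal, self-contained sketch of the closure-based monkey-patch shown in the diff.
# fast_allreduce_available / fast_allreduce are hypothetical stand-ins for the
# custom all-reduce eligibility check and kernel; they only illustrate the dispatch.
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp

original_all_reduce = dist.all_reduce  # keep the real implementation for fallback


def fast_allreduce_available(t: torch.Tensor) -> bool:
    # Hypothetical eligibility check (e.g. tensor placement / size).
    return t.is_cuda and t.numel() <= 8 * 1024 * 1024


def fast_allreduce(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical custom kernel; here it just returns the tensor unchanged.
    return t


def set_custom_reduce(device_group=None):
    # Define the replacement inside the setup function so that device_group
    # is captured as a default argument at patch time.
    def _all_reduce_closure(input_, op=ReduceOp.SUM, group=device_group, async_op=False):
        if op != ReduceOp.SUM or async_op:
            # Anything other than a plain synchronous SUM goes to the original op.
            original_all_reduce(input_, op, group, async_op)
        elif fast_allreduce_available(input_):
            input_.data = fast_allreduce(input_)
        else:
            original_all_reduce(input_, op, group, async_op)

    dist.all_reduce = _all_reduce_closure
```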