
Commit 2a36c4c

Commit message: closure
1 parent de5e1a8

1 file changed: lightllm/distributed/communication_op.py (+11 −13)
@@ -49,18 +49,6 @@ def lightllm_capture_graph():
     pass
 
 
-def _all_reduce(input_, op=ReduceOp.SUM, group=None, async_op=False):
-    if op != ReduceOp.SUM or async_op:
-        original_all_reduce(input_, op, group, async_op)
-    else:
-        if vllm_reduce is not None:
-            can_use = vllm_reduce.should_custom_ar(input_)
-            if can_use:
-                input_.data = vllm_reduce.custom_all_reduce(input_)
-                return
-        original_all_reduce(input_, op, group, async_op)
-
-
 def set_custom_reduce():
     global vllm_reduce
     global device_group
@@ -77,4 +65,14 @@ def set_custom_reduce():
         cpu_group = torch.distributed.new_group(ranks, backend="gloo")
         vllm_reduce = CustomAllreduce(cpu_group, torch.cuda.current_device())
         logger.info("Enable VLLM ALLReduce.")
-    dist.all_reduce = partial(_all_reduce, group=device_group)
+
+    def _all_reduce_closure(input_, op=ReduceOp.SUM, group=device_group, async_op=False):
+        if op != ReduceOp.SUM or async_op:
+            original_all_reduce(input_, op, group, async_op)
+        else:
+            if vllm_reduce is not None and vllm_reduce.should_custom_ar(input_):
+                input_.data = vllm_reduce.custom_all_reduce(input_)
+            else:
+                original_all_reduce(input_, op, group, async_op)
+
+    dist.all_reduce = _all_reduce_closure
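
The commit message is just "closure", so the intent has to be read from the diff: the module-level _all_reduce, previously bound with functools.partial(_all_reduce, group=device_group), is replaced by _all_reduce_closure, defined inside set_custom_reduce so that device_group is captured as the default value of group at the moment set_custom_reduce runs. One practical difference between the two bindings, sketched below with toy names (f, g_partial, g_closure are illustrative stand-ins, not lightllm code): a partial that pins group as a keyword argument raises a TypeError when a caller later passes group positionally, whereas a default argument can still be overridden either way.

    from functools import partial

    def f(x, op="sum", group=None):
        """Stand-in for the real all-reduce; just echoes its arguments."""
        return (x, op, group)

    # Old binding style: partial pins group as a keyword argument.
    g_partial = partial(f, group="device_group")
    g_partial(1)                     # (1, 'sum', 'device_group')
    g_partial(1, group="other")      # keyword override still works
    # g_partial(1, "sum", "other")   # TypeError: got multiple values for 'group'

    # New binding style: a function whose default captures the same value
    # at definition time (in the commit, a closure inside set_custom_reduce).
    def g_closure(x, op="sum", group="device_group"):
        return f(x, op, group)

    g_closure(1)                     # (1, 'sum', 'device_group')
    g_closure(1, "sum", "other")     # (1, 'sum', 'other') -- positional override works

Because the closure is created only after device_group has been initialized inside set_custom_reduce, its default always refers to the fully constructed NCCL group, and callers of the patched dist.all_reduce can still pass their own group, positionally or by keyword, when they need a different one.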
