
Commit e87821c: refactor reduce
Parent: 1a52796

2 files changed: 11 additions, 12 deletions

lightllm/distributed/communication_op.py

Lines changed: 10 additions & 10 deletions
@@ -24,6 +24,7 @@
 import torch.distributed as dist
 from torch.distributed import ReduceOp
 from lightllm.utils.log_utils import init_logger
+from functools import partial
 
 original_all_reduce = torch.distributed.all_reduce
 from contextlib import nullcontext, contextmanager
@@ -47,31 +48,30 @@ def lightllm_capture_graph():
         pass
 
 def _all_reduce(input_, op=ReduceOp.SUM, group=None, async_op=False):
-    if op != ReduceOp.SUM or group is not None or async_op:
+    if op != ReduceOp.SUM or async_op:
         original_all_reduce(input_, op, group, async_op)
     else:
         if vllm_reduce is not None:
             can_use = vllm_reduce.should_custom_ar(input_)
             if can_use:
                 input_.data = vllm_reduce.custom_all_reduce(input_)
                 return
-            original_all_reduce(input_, op, vllm_reduce.device_group, async_op)
-        else:
-            original_all_reduce(input_, op, group, async_op)
-
+        original_all_reduce(input_, op, group, async_op)
 
 def set_custom_reduce():
     global vllm_reduce
+    global device_group
     ENABLE_VLLM_REDUCE = os.getenv("ENABLE_VLLM_REDUCE", "False").upper() in [
         "ON",
         "TRUE",
         "1",
     ]
+    world_size = dist.get_world_size()
+    ranks = list(range(world_size))
+    # a dedicated new_group prevents torch's original all_reduce from hanging under CUDA graph capture
+    device_group = torch.distributed.new_group(ranks, backend="nccl")
     if ENABLE_VLLM_REDUCE and HAS_VLLM:
-        world_size = dist.get_world_size()
-        ranks = list(range(world_size))
-        device_group = torch.distributed.new_group(ranks, backend="nccl")
         cpu_group = torch.distributed.new_group(ranks, backend="gloo")
-        vllm_reduce = CustomAllreduce(cpu_group, device_group, torch.cuda.current_device())
+        vllm_reduce = CustomAllreduce(cpu_group, torch.cuda.current_device())
         logger.info("Enable VLLM ALLReduce.")
-    dist.all_reduce = _all_reduce
+    dist.all_reduce = partial(_all_reduce, group=device_group)
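
The key move in this file is binding the dedicated NCCL group into the patched collective with functools.partial, so existing call sites that invoke dist.all_reduce(tensor) with no group argument transparently run on the new group. Below is a minimal, self-contained sketch of that pattern, not the commit's code: it uses a single-process gloo group purely so it runs without GPUs, and the names _wrapped_all_reduce and dedicated_group are illustrative.

# Sketch of the monkey-patch pattern used by set_custom_reduce();
# names here are illustrative, not taken from the commit.
import os
from functools import partial

import torch
import torch.distributed as dist

original_all_reduce = dist.all_reduce  # keep a handle to the real collective


def _wrapped_all_reduce(input_, op=dist.ReduceOp.SUM, group=None, async_op=False):
    # `group` arrives pre-bound by partial() below, so a caller writing the
    # usual dist.all_reduce(t) transparently lands on the dedicated group.
    return original_all_reduce(input_, op=op, group=group, async_op=async_op)


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29511")
    dist.init_process_group("gloo", rank=0, world_size=1)

    dedicated_group = dist.new_group(ranks=[0], backend="gloo")
    dist.all_reduce = partial(_wrapped_all_reduce, group=dedicated_group)

    t = torch.ones(4)
    dist.all_reduce(t)  # routes through the wrapper with the bound group
    print(t.tolist())   # [1.0, 1.0, 1.0, 1.0] since world_size == 1
    dist.destroy_process_group()

One subtlety of partial(_all_reduce, group=device_group): a caller that passes group= by keyword still overrides the bound value, but a caller passing the group positionally would raise a "multiple values for argument 'group'" TypeError, so the patch implicitly assumes call sites use the keyword or no-group form.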

lightllm/distributed/custom_all_reduce.py

Lines changed: 1 addition & 2 deletions
@@ -49,7 +49,7 @@ class CustomAllreduce:
     _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
 
     # max_size: max supported allreduce size
-    def __init__(self, group: ProcessGroup, device_group: ProcessGroup, device: Union[int, str, torch.device], max_size=8192 * 1024) -> None:
+    def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device], max_size=8192 * 1024) -> None:
         """
         Args:
             group: the process group to work on. If None, it will use the
@@ -69,7 +69,6 @@ def __init__(self, group: ProcessGroup, device_group: ProcessGroup, device: Unio
         return
 
         self.group = group
-        self.device_group = device_group
         assert dist.get_backend(group) != dist.Backend.NCCL, "CustomAllreduce should be attached to a non-NCCL group."
 
         rank = dist.get_rank(group=self.group)
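
With device_group dropped from the constructor, wiring up the helper matches what set_custom_reduce() now does in the first file. A short usage sketch, assuming torch.distributed is already initialized with NCCL on a supported world size (2, 4, 6, or 8) and that the class is importable from this module's path; the variable names are illustrative.

# Usage sketch for the slimmed-down constructor; assumes dist.init_process_group
# has already run with the NCCL backend and each rank has set its CUDA device.
import torch
import torch.distributed as dist

from lightllm.distributed.custom_all_reduce import CustomAllreduce

ranks = list(range(dist.get_world_size()))
# CustomAllreduce asserts its group is non-NCCL, hence a gloo group here.
cpu_group = dist.new_group(ranks, backend="gloo")
reducer = CustomAllreduce(cpu_group, torch.cuda.current_device(), max_size=8192 * 1024)

# Mirrors the gate in _all_reduce: use the custom kernel when eligible,
# otherwise fall back to the stock collective.
t = torch.ones(4096, device="cuda")
if reducer.should_custom_ar(t):
    t.data = reducer.custom_all_reduce(t)
else:
    dist.all_reduce(t)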
