Commit 1a52796

fix reduce stuck on h100 with graph (#654)
Co-authored-by: baishihao <[email protected]>
1 parent b03b60e commit 1a52796

File tree

3 files changed (+13, -11 lines)

lightllm/distributed/communication_op.py
lightllm/distributed/custom_all_reduce.py
lightllm/server/router/model_infer/mode_backend/base_backend.py

lightllm/distributed/communication_op.py

Lines changed: 6 additions & 6 deletions
@@ -37,7 +37,6 @@
 vllm_reduce = None
 logger = init_logger(__name__)
 
-
 @contextmanager
 def lightllm_capture_graph():
     if vllm_reduce is not None:
@@ -47,22 +46,22 @@ def lightllm_capture_graph():
         yield
     pass
 
-
 def _all_reduce(input_, op=ReduceOp.SUM, group=None, async_op=False):
-    if op != ReduceOp.SUM or group is not None or async_op or vllm_reduce is None:
+    if op != ReduceOp.SUM or group is not None or async_op:
         original_all_reduce(input_, op, group, async_op)
     else:
         if vllm_reduce is not None:
             can_use = vllm_reduce.should_custom_ar(input_)
             if can_use:
                 input_.data = vllm_reduce.custom_all_reduce(input_)
                 return
-        original_all_reduce(input_, op, group, async_op)
+            original_all_reduce(input_, op, vllm_reduce.device_group, async_op)
+        else:
+            original_all_reduce(input_, op, group, async_op)
 
 
 def set_custom_reduce():
     global vllm_reduce
-
     ENABLE_VLLM_REDUCE = os.getenv("ENABLE_VLLM_REDUCE", "False").upper() in [
         "ON",
         "TRUE",
@@ -71,7 +70,8 @@ def set_custom_reduce():
     if ENABLE_VLLM_REDUCE and HAS_VLLM:
         world_size = dist.get_world_size()
         ranks = list(range(world_size))
+        device_group = torch.distributed.new_group(ranks, backend="nccl")
         cpu_group = torch.distributed.new_group(ranks, backend="gloo")
-        vllm_reduce = CustomAllreduce(cpu_group, torch.cuda.current_device())
+        vllm_reduce = CustomAllreduce(cpu_group, device_group, torch.cuda.current_device())
         logger.info("Enable VLLM ALLReduce.")
         dist.all_reduce = _all_reduce
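
Note on this file: the vllm_reduce check moves into the else branch, and the NCCL fallback inside the monkey-patched all-reduce now runs on the dedicated device_group created in set_custom_reduce instead of the default group. The sketch below restates that dispatch pattern outside the diff; the reducer object and its attributes (should_custom_ar, custom_all_reduce, device_group) mirror the names in the diff, while everything else is illustrative.

# Sketch of the dispatch pattern introduced above (illustrative, not the exact file).
# `reducer` stands in for the module-level vllm_reduce object: it exposes
# should_custom_ar(), custom_all_reduce(), and a dedicated NCCL device_group.
import torch.distributed as dist
from torch.distributed import ReduceOp

original_all_reduce = dist.all_reduce  # keep a handle to the unpatched op


def patched_all_reduce(input_, op=ReduceOp.SUM, group=None, async_op=False, reducer=None):
    # Anything the custom kernel does not cover goes straight to the original op.
    if op != ReduceOp.SUM or group is not None or async_op or reducer is None:
        return original_all_reduce(input_, op, group, async_op)
    if reducer.should_custom_ar(input_):
        # The custom kernel is out-of-place; write the result back in place.
        input_.data = reducer.custom_all_reduce(input_)
        return None
    # Fall back to NCCL, but on the reducer's dedicated device group rather
    # than the default group; this is the substance of the change above.
    return original_all_reduce(input_, op, reducer.device_group, async_op)

As in the diff, set_custom_reduce installs such a function by assigning it to dist.all_reduce after creating both the gloo cpu_group and the NCCL device_group.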

lightllm/distributed/custom_all_reduce.py

Lines changed: 6 additions & 5 deletions
@@ -30,7 +30,7 @@
 from lightllm.utils.log_utils import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cuda_device_count_stateless
-
+from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 ops.meta_size()
 custom_ar = True
 
@@ -49,7 +49,7 @@ class CustomAllreduce:
     _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
 
     # max_size: max supported allreduce size
-    def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device], max_size=8192 * 1024) -> None:
+    def __init__(self, group: ProcessGroup, device_group: ProcessGroup, device: Union[int, str, torch.device], max_size=8192 * 1024) -> None:
         """
         Args:
             group: the process group to work on. If None, it will use the
@@ -69,7 +69,7 @@ def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device], m
             return
 
         self.group = group
-
+        self.device_group = device_group
         assert dist.get_backend(group) != dist.Backend.NCCL, "CustomAllreduce should be attached to a non-NCCL group."
 
         rank = dist.get_rank(group=self.group)
@@ -226,7 +226,7 @@ def all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None, registered:
         buffer.
         """
         if out is None:
-            out = torch.empty_like(inp)
+            out = g_cache_manager.alloc_tensor(inp.shape, inp.dtype, device=inp.device, is_graph_out=False)
         if registered:
             ops.all_reduce(self._ptr, inp, out, 0, 0)
         else:
@@ -244,7 +244,8 @@ def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
             else:
                 # If warm up, mimic the allocation pattern since custom
                 # allreduce is out-of-place.
-                return torch.empty_like(input)
+                out = g_cache_manager.alloc_tensor(input.shape, input.dtype, device=input.device, is_graph_out=False)
+                return out
         else:
             # Note: outside of cuda graph context, custom allreduce incurs a
             # cost of cudaMemcpy, which should be small (<=1% of overall
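
Note on this file: CustomAllreduce now takes the extra device_group in its constructor and, for the graph path, allocates output tensors through g_cache_manager.alloc_tensor(...) instead of torch.empty_like(...), so memory handed out during CUDA-graph capture (and during the warm-up pass that mimics capture) comes from lightllm's cache tensor manager. The toy pool below only illustrates why a managed allocator matters under capture; it is an assumption for exposition, not lightllm's g_cache_manager implementation.

# Toy pooled allocator, standing in for g_cache_manager (illustrative only).
# Under CUDA-graph capture the replayed kernels must see stable addresses, so
# outputs are served from tensors the pool already owns instead of fresh
# allocations on every call.
import torch


class TensorPool:
    def __init__(self):
        self._cache = {}

    def alloc_tensor(self, shape, dtype, device="cuda", is_graph_out=False):
        key = (tuple(shape), dtype, str(device), is_graph_out)
        if key not in self._cache:
            self._cache[key] = torch.empty(shape, dtype=dtype, device=device)
        # The same storage is returned for repeated identical requests.
        return self._cache[key]


pool = TensorPool()


def allreduce_output_like(inp: torch.Tensor) -> torch.Tensor:
    # Mirrors the call pattern used in the diff: shape, dtype and device come
    # from the input tensor; is_graph_out=False marks it as an internal buffer.
    return pool.alloc_tensor(inp.shape, inp.dtype, device=inp.device, is_graph_out=False)

Reusing storage this way is safe here because the caller in communication_op.py consumes the result immediately by assigning it to input_.data.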

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 1 addition & 0 deletions
@@ -84,6 +84,7 @@ def init_model(self, kvargs):
             init_method=f'tcp://127.0.0.1:{kvargs["nccl_port"]}',
             rank=self.tp_rank,
             world_size=self.world_size,
+            device_id=torch.device(f"cuda:{self.tp_rank}"),
         )
 
         from lightllm.distributed import set_custom_reduce
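
Note on this file: passing device_id to torch.distributed.init_process_group binds each rank's NCCL communicator to a specific CUDA device at initialization time (the keyword exists in recent PyTorch releases). A minimal sketch of the call shape; the port, rank, and world size are placeholders, and only the device_id keyword is taken from the commit.

# Minimal sketch: bind each rank to its GPU when the process group is created.
# Values other than device_id are placeholders for illustration.
import torch
import torch.distributed as dist


def init_distributed(rank: int, world_size: int, port: int = 29500) -> None:
    torch.cuda.set_device(rank)
    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://127.0.0.1:{port}",
        rank=rank,
        world_size=world_size,
        device_id=torch.device(f"cuda:{rank}"),  # the keyword added in this commit
    )

With the device pinned at init, the extra new_group(ranks, backend="nccl") created in set_custom_reduce has an unambiguous device mapping, which is presumably what keeps the reduce from getting stuck during graph capture on H100.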
