Skip to content

Commit e2fa100

Browse files
[FEAT] Add custom allreduce from AITER to vllm and (#629)
control it by the env flag VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE (default: True) Signed-off-by: zejunchen-zejun <zejun.chen@amd.com>
1 parent ea7d874 commit e2fa100

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

vllm/distributed/device_communicators/cuda_communicator.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

4+
from functools import cache
45
from typing import Optional, Union
56

67
import torch
@@ -15,6 +16,14 @@
1516
logger = init_logger(__name__)
1617

1718

@cache
def is_rocm_aiter_custom_allreduce_enabled() -> bool:
    """Return True when AITER's custom allreduce should be used.

    Requires all three conditions: running on a ROCm platform, the
    VLLM_ROCM_USE_AITER flag, and the
    VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE flag. The result is cached
    after the first call, so the environment is consulted only once.
    """
    # Short-circuit: env flags are only read when actually on ROCm.
    return (current_platform.is_rocm()
            and envs.VLLM_ROCM_USE_AITER
            and envs.VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE)
25+
26+
1827
class CudaCommunicator(DeviceCommunicatorBase):
1928

2029
def __init__(self,
@@ -38,8 +47,12 @@ def __init__(self,
3847
self.use_custom_allreduce = use_custom_allreduce
3948

4049
# lazy import to avoid documentation build error
41-
from vllm.distributed.device_communicators.custom_all_reduce import (
42-
CustomAllreduce)
50+
if is_rocm_aiter_custom_allreduce_enabled():
51+
from aiter.dist.custom_all_reduce import CustomAllreduce
52+
logger.info("Using aiter.dist.custom_all_reduce for ROCm platform")
53+
else:
54+
from vllm.distributed.device_communicators.custom_all_reduce import ( # noqa: E501
55+
CustomAllreduce)
4356
from vllm.distributed.device_communicators.pynccl import (
4457
PyNcclCommunicator)
4558
from vllm.distributed.device_communicators.quick_all_reduce import (

vllm/envs.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@
9595
VLLM_ROCM_USE_AITER_MLA: bool = True
9696
VLLM_ROCM_USE_AITER_MHA: bool = True
9797
VLLM_ROCM_USE_AITER_ROPE: bool = False
98+
VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE: bool = True
9899
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
99100
VLLM_ROCM_FP8_PADDING: bool = True
100101
VLLM_ROCM_MOE_PADDING: bool = True
@@ -709,6 +710,13 @@ def get_vllm_port() -> Optional[int]:
709710
lambda: (os.getenv("VLLM_ROCM_USE_AITER_ROPE", "False").lower() in
710711
("true", "1")),
711712

713+
# Whether to use aiter custom allreduce for ROCm platform.
714+
# Enabled by default; set to False to fall back to vLLM's built-in custom allreduce.
715+
"VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE":
716+
lambda:
717+
(os.getenv("VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE", "True").lower() in
718+
("true", "1")),
719+
712720
# use rocm skinny gemms
713721
"VLLM_ROCM_USE_SKINNY_GEMM":
714722
lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in

0 commit comments

Comments (0)