File tree Expand file tree Collapse file tree 2 files changed +23
-2
lines changed
distributed/device_communicators Expand file tree Collapse file tree 2 files changed +23
-2
lines changed Original file line number Diff line number Diff line change 11# SPDX-License-Identifier: Apache-2.0
22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
4+ from functools import cache
45from typing import Optional , Union
56
67import torch
1516logger = init_logger (__name__ )
1617
1718
19+ @cache
20+ def is_rocm_aiter_custom_allreduce_enabled () -> bool :
21+ """Check if aiter custom allreduce is enabled for ROCm platform."""
22+ return current_platform .is_rocm () \
23+ and envs .VLLM_ROCM_USE_AITER \
24+ and envs .VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE
25+
26+
1827class CudaCommunicator (DeviceCommunicatorBase ):
1928
2029 def __init__ (self ,
@@ -38,8 +47,12 @@ def __init__(self,
3847 self .use_custom_allreduce = use_custom_allreduce
3948
4049 # lazy import to avoid documentation build error
41- from vllm .distributed .device_communicators .custom_all_reduce import (
42- CustomAllreduce )
50+ if is_rocm_aiter_custom_allreduce_enabled ():
51+ from aiter .dist .custom_all_reduce import CustomAllreduce
52+ logger .info ("Using aiter.dist.custom_all_reduce for ROCm platform" )
53+ else :
54+ from vllm .distributed .device_communicators .custom_all_reduce import ( # noqa: E501
55+ CustomAllreduce )
4356 from vllm .distributed .device_communicators .pynccl import (
4457 PyNcclCommunicator )
4558 from vllm .distributed .device_communicators .quick_all_reduce import (
Original file line number Diff line number Diff line change 9595 VLLM_ROCM_USE_AITER_MLA : bool = True
9696 VLLM_ROCM_USE_AITER_MHA : bool = True
9797 VLLM_ROCM_USE_AITER_ROPE : bool = False
98+ VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE : bool = True
9899 VLLM_ROCM_USE_SKINNY_GEMM : bool = True
99100 VLLM_ROCM_FP8_PADDING : bool = True
100101 VLLM_ROCM_MOE_PADDING : bool = True
@@ -709,6 +710,13 @@ def get_vllm_port() -> Optional[int]:
709710 lambda : (os .getenv ("VLLM_ROCM_USE_AITER_ROPE" , "False" ).lower () in
710711 ("true" , "1" )),
711712
713+ # Whether to use aiter custom allreduce for ROCm platform.
714+ # Enabled by default; when disabled, vLLM built-in custom allreduce is used.
715+ "VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE" :
716+ lambda :
717+ (os .getenv ("VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE" , "True" ).lower () in
718+ ("true" , "1" )),
719+
712720 # use rocm skinny gemms
713721 "VLLM_ROCM_USE_SKINNY_GEMM" :
714722 lambda : (os .getenv ("VLLM_ROCM_USE_SKINNY_GEMM" , "True" ).lower () in
You can’t perform that action at this time.
0 commit comments