Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions vllm/distributed/device_communicators/cuda_communicator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from functools import cache
from typing import Optional, Union

import torch
Expand All @@ -15,6 +16,14 @@
logger = init_logger(__name__)


@cache
def is_rocm_aiter_custom_allreduce_enabled() -> bool:
    """Return True when the aiter custom allreduce should be used.

    Requires all three conditions: running on ROCm, the global AITER
    switch (``VLLM_ROCM_USE_AITER``), and the allreduce-specific flag
    (``VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE``). The result is cached
    since none of these change during a process's lifetime.
    """
    return (current_platform.is_rocm()
            and envs.VLLM_ROCM_USE_AITER
            and envs.VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE)


class CudaCommunicator(DeviceCommunicatorBase):

def __init__(self,
Expand All @@ -38,8 +47,12 @@ def __init__(self,
self.use_custom_allreduce = use_custom_allreduce

# lazy import to avoid documentation build error
from vllm.distributed.device_communicators.custom_all_reduce import (
CustomAllreduce)
if is_rocm_aiter_custom_allreduce_enabled():
from aiter.dist.custom_all_reduce import CustomAllreduce
logger.info("Using aiter.dist.custom_all_reduce for ROCm platform")
else:
from vllm.distributed.device_communicators.custom_all_reduce import ( # noqa: E501
CustomAllreduce)
from vllm.distributed.device_communicators.pynccl import (
PyNcclCommunicator)
from vllm.distributed.device_communicators.quick_all_reduce import (
Expand Down
8 changes: 8 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
VLLM_ROCM_USE_AITER_MLA: bool = True
VLLM_ROCM_USE_AITER_MHA: bool = True
VLLM_ROCM_USE_AITER_ROPE: bool = False
VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE: bool = True
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True
Expand Down Expand Up @@ -709,6 +710,13 @@ def get_vllm_port() -> Optional[int]:
lambda: (os.getenv("VLLM_ROCM_USE_AITER_ROPE", "False").lower() in
("true", "1")),

    # Whether to use the aiter custom allreduce on the ROCm platform.
    # Enabled by default; set to 0/False to fall back to vLLM's
    # built-in custom allreduce.
"VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE":
lambda:
(os.getenv("VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE", "True").lower() in
("true", "1")),

# use rocm skinny gemms
"VLLM_ROCM_USE_SKINNY_GEMM":
lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in
Expand Down
Loading