Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 169 additions & 34 deletions vllm/distributed/device_communicators/cuda_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from functools import cache
from typing import Optional, Union
from typing import Callable, Optional, Union

import torch
from torch.distributed import ProcessGroup
Expand All @@ -23,6 +23,140 @@
and envs.VLLM_ROCM_USE_AITER \
and envs.VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE

# ROCm all-reduce dispatcher: selects the best-performing all-reduce
# implementation among those available, based on the payload size of
# the input tensor. Only the AMD ROCm platform is supported.
class ROCmAllreduceDispatcher:
    """Pick an all-reduce implementation for a tensor on AMD ROCm devices.

    The dispatcher keeps a registry of all-reduce implementations
    (``"pynccl"``, and optionally ``"vllm_ca"`` / ``"vllm_qr"``) and, for
    each input tensor, returns the callable to run. The choice depends on
    the device architecture family, the size of the process group and the
    payload size of the tensor compared against a per-architecture
    threshold table (unit: KB).

    For a supported architecture and group size:

    * payload <= threshold: custom all-reduce (``"vllm_ca"``) if it is
      registered and its check accepts the tensor;
    * payload > threshold: quick all-reduce (``"vllm_qr"``) if it is
      registered and its check accepts the tensor;
    * otherwise: the pynccl fallback.

    Args:
        group: Process group used to query the world size and to build
            the quick all-reduce communicator.
        device: Device handed to the quick all-reduce communicator.
        ca_comm: Optional custom all-reduce communicator exposing
            ``custom_all_reduce`` and ``should_custom_ar``.
        pynccl_comm: pynccl communicator exposing ``all_reduce``. It is
            the fallback implementation and is therefore required.

    Raises:
        ValueError: If ``pynccl_comm`` is ``None``.
    """

    def __init__(self,
                 group: ProcessGroup,
                 device: Union[int, str, torch.device],
                 ca_comm=None,
                 pynccl_comm=None):
        # The fallback path is always pynccl; fail early with a clear
        # message instead of an AttributeError on ``None.all_reduce``.
        if pynccl_comm is None:
            raise ValueError(
                "ROCmAllreduceDispatcher requires a pynccl communicator "
                "as its fallback all-reduce implementation.")

        self.process_group = group
        self.device = device
        self.cur_device_arch = self._get_current_device_arch()
        self.supported_device_archs = ["gfx94", "gfx95"]

        self.tp_size = torch.distributed.get_world_size(
            group=self.process_group)

        # Dispatch thresholds (unit: KB) keyed by tp_size.
        self.gfx95_thresholds = {
            2: 512,
            4: 2048,
            8: 32768,
        }
        self.gfx94_thresholds = {
            2: 2048,
            4: 4096,
            8: 8192,
        }

        # impl name -> (all-reduce callable, optional "should use" check).
        # A check of ``None`` means the implementation has no precondition.
        self.available_allreduce_impls: dict[
            str, tuple[Callable, Optional[Callable]]] = {}

        self.available_allreduce_impls["pynccl"] = (pynccl_comm.all_reduce,
                                                    None)
        self.fallback_impl = pynccl_comm.all_reduce

        if ca_comm is not None:
            self.available_allreduce_impls["vllm_ca"] = (
                ca_comm.custom_all_reduce, ca_comm.should_custom_ar)

        # Initialize a custom quick all-reduce implementation for AMD.
        # Quick reduce is designed as a complement to custom allreduce.
        # Based on quickreduce (https://github.com/mk1-project/quickreduce).
        # If it's a rocm, 'use_custom_allreduce==True' means it must
        # currently be an MI300 series.
        from vllm.distributed.device_communicators.quick_all_reduce import (
            QuickAllReduce)
        self.qr_comm = QuickAllReduce(group=self.process_group,
                                      device=self.device)
        # The communicator was just constructed, so it is always
        # registered; whether it is used for a given tensor is decided
        # by ``should_quick_allreduce`` at dispatch time.
        self.available_allreduce_impls["vllm_qr"] = (
            self.qr_comm.quick_all_reduce,
            self.qr_comm.should_quick_allreduce)

    def _get_current_device_arch(self) -> str:
        """Return the architecture family of the current device.

        ``gcnArchName`` values containing ``"gfx95"`` or ``"gfx94"`` are
        collapsed to that family name; any other value is returned
        unchanged.
        """
        # TODO(zejun): Add more device architectures
        props = torch.cuda.get_device_properties("cuda")
        device_arch = props.gcnArchName
        for family in ("gfx95", "gfx94"):
            if family in device_arch:
                return family
        return device_arch

    def _should_allreduce(self, input_: torch.Tensor, impl_name: str) -> bool:
        """Tell whether implementation ``impl_name`` may handle ``input_``.

        Returns ``False`` when the implementation is not registered,
        ``True`` when it is registered without a check callable, and the
        result of its check callable otherwise.
        """
        impl = self.available_allreduce_impls.get(impl_name)
        if impl is None:
            return False
        check = impl[1]
        if check is None:
            # No precondition registered (e.g. the "pynccl" entry).
            return True
        return check(input_)

    def _dispatch_by_threshold(self,
                               input_: torch.Tensor,
                               payload_size_KB: int,
                               tp_size: int,
                               thresholds: dict) -> Callable:
        """Shared threshold-based selection used by every architecture.

        Payloads up to the threshold for ``tp_size`` prefer ``"vllm_ca"``,
        larger payloads prefer ``"vllm_qr"``. The fallback is returned
        when ``tp_size`` has no threshold or the preferred implementation
        is unavailable or rejects the tensor.
        """
        threshold = thresholds.get(tp_size)
        if threshold is None:
            return self.fallback_impl

        preferred = "vllm_ca" if payload_size_KB <= threshold else "vllm_qr"
        if self._should_allreduce(input_, preferred):
            return self.available_allreduce_impls[preferred][0]
        return self.fallback_impl

    def _dispatch_gfx94(self,
                        input_: torch.Tensor,
                        payload_size_KB: int,
                        tp_size: int) -> Callable:
        """Dispatch implementation for gfx94 architecture."""
        return self._dispatch_by_threshold(input_, payload_size_KB, tp_size,
                                           self.gfx94_thresholds)

    def _dispatch_gfx95(self,
                        input_: torch.Tensor,
                        payload_size_KB: int,
                        tp_size: int) -> Callable:
        """Dispatch implementation for gfx95 architecture."""
        return self._dispatch_by_threshold(input_, payload_size_KB, tp_size,
                                           self.gfx95_thresholds)

    def _dispatch_impl(self,
                       input_: torch.Tensor,
                       payload_size_KB: int,
                       device_arch: str,
                       tp_size: int) -> Callable:
        """Route to the per-architecture dispatcher, else the fallback."""
        if device_arch not in self.supported_device_archs:
            # Lazy %-formatting: the message is only built when DEBUG
            # logging is enabled.
            logger.debug(
                "Device architecture %s not supported, using pynccl",
                device_arch)
            return self.fallback_impl

        if device_arch == "gfx95":
            return self._dispatch_gfx95(input_, payload_size_KB, tp_size)
        if device_arch == "gfx94":
            return self._dispatch_gfx94(input_, payload_size_KB, tp_size)
        # for other devices, fallback to pynccl
        return self.fallback_impl

    def dispatch(self, input_: torch.Tensor) -> Callable:
        """Return the all-reduce callable to use for ``input_``.

        The payload size is ``numel * element_size`` expressed in whole
        KB (rounded down).
        """
        # unit: KB; integer floor division keeps the result exact.
        payload_size_kb = input_.numel() * input_.element_size() // 1024
        return self._dispatch_impl(input_,
                                   payload_size_kb,
                                   self.cur_device_arch,
                                   self.tp_size)

class CudaCommunicator(DeviceCommunicatorBase):

Expand Down Expand Up @@ -55,8 +189,6 @@
CustomAllreduce)
from vllm.distributed.device_communicators.pynccl import (
PyNcclCommunicator)
from vllm.distributed.device_communicators.quick_all_reduce import (
QuickAllReduce)
from vllm.distributed.device_communicators.symm_mem import (
SymmMemCommunicator)

Expand All @@ -68,8 +200,11 @@
)

self.ca_comm: Optional[CustomAllreduce] = None
self.qr_comm: Optional[QuickAllReduce] = None
self.symm_mem_comm: Optional[SymmMemCommunicator] = None

# Initialize a custom all-reduce dispatcher for ROCm platform
self.rocm_allreduce_dispatcher: Optional[ROCmAllreduceDispatcher] = None

if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
self.ca_comm = CustomAllreduce(
Expand All @@ -78,13 +213,13 @@
)

if current_platform.is_rocm():
# Initialize a custom quick all-reduce implementation for AMD.
# Quick reduce is designed as a complement to custom allreduce.
# Based on quickreduce (https://github.com/mk1-project/quickreduce).
# If it's a rocm, 'use_custom_allreduce==True' means it must
# currently be an MI300 series.
self.qr_comm = QuickAllReduce(group=self.cpu_group,
device=self.device)
self.rocm_allreduce_dispatcher = \
ROCmAllreduceDispatcher(group=self.cpu_group,
device=self.device,
ca_comm=self.ca_comm,
pynccl_comm=self.pynccl_comm)
logger.info("Initializing ROCm allreduce dispatcher.")

if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
self.symm_mem_comm = SymmMemCommunicator(
group=self.cpu_group,
Expand Down Expand Up @@ -113,29 +248,29 @@
raise ValueError(f"Unknown all2all backend: {all2all_backend}")

def all_reduce(self, input_):
# always try quick reduce first, then custom allreduce,
# and then pynccl. (quick reduce just for ROCM MI3*)
qr_comm = self.qr_comm
if qr_comm is not None and not qr_comm.disabled and \
qr_comm.should_quick_allreduce(input_):
out = qr_comm.quick_all_reduce(input_)
assert out is not None
return out
ca_comm = self.ca_comm
if ca_comm is not None and not ca_comm.disabled and \
ca_comm.should_custom_ar(input_):
out = ca_comm.custom_all_reduce(input_)
assert out is not None
return out
symm_mem_comm = self.symm_mem_comm
if symm_mem_comm is not None and \
symm_mem_comm.should_use_symm_mem(input_):
out = symm_mem_comm.all_reduce(input_)
assert out is not None
return out
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
out = pynccl_comm.all_reduce(input_)
if current_platform.is_rocm() and self.rocm_allreduce_dispatcher is not None:
op = self.rocm_allreduce_dispatcher.dispatch(input_)
logger.debug(f"ROCm allreduce dispatcher dispatched: {op}")
out = op(input_)

Check failure on line 254 in vllm/distributed/device_communicators/cuda_communicator.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (G004)

vllm/distributed/device_communicators/cuda_communicator.py:254:26: G004 Logging statement uses f-string
else:
ca_comm = self.ca_comm
if ca_comm is not None and not ca_comm.disabled and \
ca_comm.should_custom_ar(input_):
out = ca_comm.custom_all_reduce(input_)
assert out is not None
return out
symm_mem_comm = self.symm_mem_comm
if symm_mem_comm is not None and \
symm_mem_comm.should_use_symm_mem(input_):
out = symm_mem_comm.all_reduce(input_)
assert out is not None
return out

pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
out = pynccl_comm.all_reduce(input_)

# fallback to the default all-reduce using PyTorch.
if out is None:
# fall back to the default all-reduce using PyTorch.
# this usually happens during testing.
Expand Down
Loading