diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py
index 64cafa40a0a7..fd286b1854f5 100644
--- a/vllm/distributed/device_communicators/cuda_communicator.py
+++ b/vllm/distributed/device_communicators/cuda_communicator.py
@@ -1,8 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from functools import cache
-from typing import Optional, Union
+from typing import Callable, Optional, Union
 
 import torch
 from torch.distributed import ProcessGroup
@@ -16,13 +15,167 @@
 logger = init_logger(__name__)
 
-@cache
-def is_rocm_aiter_custom_allreduce_enabled() -> bool:
-    """Check if aiter custom allreduce is enabled for ROCm platform."""
-    return current_platform.is_rocm() \
-        and envs.VLLM_ROCM_USE_AITER \
-        and envs.VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE
+# ROCm allreduce dispatcher that selects the most performant allreduce
+# implementation for the payload size of the input tensor, out of the
+# implementations available at runtime. AMD ROCm platform only.
+class ROCmAllreduceDispatcher:
+
+    def __init__(self,
+                 group: ProcessGroup,
+                 device: Union[int, str, torch.device],
+                 ca_comm=None,
+                 pynccl_comm=None):
+        self.process_group = group
+        self.device = device
+        self.cur_device_arch = self._get_current_device_arch()
+
+        self.tp_size = torch.distributed.get_world_size(
+            group=self.process_group)
+
+        # Covers MI300/MI308 (MI30X) and MI350/MI355 (MI35X).
+        self.supported_device_archs = ["MI30X", "MI35X"]
+
+        # Per-tp_size payload thresholds (KB): at or below the threshold
+        # the custom allreduce is preferred; above it, quick allreduce.
+        self.MI35x_thresholds = {
+            2: 1024,
+            4: 4096,
+            8: 8192,
+        }
+        self.MI30x_thresholds = {
+            2: 1024,
+            4: 2048,
+            8: 2048,
+        }
+
+        # allreduce name -> (allreduce impl, eligibility check impl)
+        self.available_allreduce_impls: dict[
+            str, tuple[Callable, Optional[Callable]]] = {}
+
+        # pynccl is the mandatory fallback; it has no eligibility check.
+        assert pynccl_comm is not None
+        self.available_allreduce_impls["pynccl"] = \
+            (pynccl_comm.all_reduce, None)
+        self.fallback_impl = pynccl_comm.all_reduce
+
+        if ca_comm is not None and not ca_comm.disabled:
+            self.available_allreduce_impls["vllm_ca"] = \
+                (ca_comm.custom_all_reduce, ca_comm.should_custom_ar)
+
+        # Initialize a custom quick all-reduce implementation for AMD.
+        # Quick reduce is designed as a complement to custom allreduce.
+        # Based on quickreduce (https://github.com/mk1-project/quickreduce).
+        # On ROCm, 'use_custom_allreduce == True' currently implies an
+        # MI300 series GPU.
+        from vllm.distributed.device_communicators.quick_all_reduce import (
+            QuickAllReduce)
+        self.qr_comm = QuickAllReduce(group=self.process_group,
+                                      device=self.device)
+        if self.qr_comm is not None and not self.qr_comm.disabled:
+            self.available_allreduce_impls["vllm_qr"] = \
+                (self.qr_comm.quick_all_reduce,
+                 self.qr_comm.should_quick_allreduce)
+
+        # Initialize a custom all-reduce implementation from aiter.
+        if self._is_aiter_custom_allreduce_available():
+            from aiter.dist.custom_all_reduce import CustomAllreduce \
+                as AiterCustomAllreduce
+            self.aiter_ca_comm = AiterCustomAllreduce(
+                group=self.process_group,
+                device=self.device,
+            )
+            if self.aiter_ca_comm is not None:
+                self.available_allreduce_impls["aiter_ca"] = \
+                    (self.aiter_ca_comm.custom_all_reduce,
+                     self.aiter_ca_comm.should_custom_ar)
+
+    def _is_aiter_custom_allreduce_available(self) -> bool:
+        """Check if the aiter custom allreduce is available on ROCm."""
+        if not envs.VLLM_ROCM_USE_AITER:
+            return False
+
+        try:
+            from aiter.dist.custom_all_reduce import (  # noqa: F401
+                CustomAllreduce)
+            return True
+        except ImportError:
+            return False
+
+    def _get_current_device_arch(self) -> str:
+        """Get the micro-architecture family of the current device."""
+        # TODO(zejun): Add more device architectures
+        device_arch = torch.cuda.get_device_properties(
+            self.device).gcnArchName
+        if "gfx95" in device_arch:
+            return "MI35X"
+        elif "gfx94" in device_arch:
+            return "MI30X"
+        elif "gfx11" in device_arch:
+            return "RX"
+        else:
+            return device_arch
+
+    def _should_allreduce(self, input_: torch.Tensor, impl_name: str) -> bool:
+        if impl_name not in self.available_allreduce_impls:
+            return False
+        check_impl = self.available_allreduce_impls[impl_name][1]
+        # A missing check (e.g. pynccl) means "always eligible".
+        return check_impl is None or check_impl(input_)
+
+    def _dispatch_mi30x(self,
+                        input_: torch.Tensor,
+                        payload_size_KB: int,
+                        tp_size: int) -> Callable:
+        """Dispatch implementation for the MI30X architecture."""
+        if tp_size not in self.MI30x_thresholds:
+            return self.fallback_impl
+
+        threshold = self.MI30x_thresholds[tp_size]
+
+        if payload_size_KB <= threshold and \
+                self._should_allreduce(input_, "vllm_ca"):
+            return self.available_allreduce_impls["vllm_ca"][0]
+
+        if self._should_allreduce(input_, "vllm_qr"):
+            return self.available_allreduce_impls["vllm_qr"][0]
+
+        return self.fallback_impl
+
+    def _dispatch_mi35x(self,
+                        input_: torch.Tensor,
+                        payload_size_KB: int,
+                        tp_size: int) -> Callable:
+        """Dispatch implementation for the MI35X architecture."""
+        if tp_size not in self.MI35x_thresholds:
+            return self.fallback_impl
+
+        threshold = self.MI35x_thresholds[tp_size]
+
+        if payload_size_KB <= threshold and \
+                self._should_allreduce(input_, "aiter_ca"):
+            return self.available_allreduce_impls["aiter_ca"][0]
+
+        if self._should_allreduce(input_, "vllm_qr"):
+            return self.available_allreduce_impls["vllm_qr"][0]
+
+        return self.fallback_impl
+
+    def _dispatch_impl(self,
+                       input_: torch.Tensor,
+                       payload_size_KB: int,
+                       device_arch: str,
+                       tp_size: int) -> Callable:
+        if device_arch not in self.supported_device_archs:
+            logger.debug(
+                "Device architecture %s not supported, using pynccl",
+                device_arch)
+            return self.fallback_impl
+
+        if device_arch == "MI35X":
+            return self._dispatch_mi35x(input_, payload_size_KB, tp_size)
+        elif device_arch == "MI30X":
+            return self._dispatch_mi30x(input_, payload_size_KB, tp_size)
+        else:
+            # For other devices, fall back to pynccl.
+            return self.fallback_impl
+
+    def dispatch(self, input_: torch.Tensor) -> Callable:
+        """Dispatch the allreduce implementation for the input tensor."""
+        # unit: KB
+        payload_size_KB = int(input_.numel() * input_.element_size() / 1024.0)
+        return self._dispatch_impl(input_,
+                                   payload_size_KB,
+                                   self.cur_device_arch,
+                                   self.tp_size)
 
 
 class CudaCommunicator(DeviceCommunicatorBase):
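The policy the dispatcher encodes is a per-`tp_size` payload threshold: payloads at or below the threshold go to the architecture's custom allreduce, larger ones to quick allreduce, and anything unclaimed to pynccl. A minimal, runnable sketch of that policy (illustrative only: it hard-codes the MI35X table from the diff and skips the `should_*` eligibility checks the real code also consults):

```python
import torch

# MI35X thresholds (KB) copied from the diff; tp_size -> threshold.
MI35X_THRESHOLDS_KB = {2: 1024, 4: 4096, 8: 8192}

def pick_impl(t: torch.Tensor, tp_size: int) -> str:
    payload_kb = t.numel() * t.element_size() // 1024
    threshold = MI35X_THRESHOLDS_KB.get(tp_size)
    if threshold is None:
        return "pynccl"          # unknown tp_size: safe fallback
    if payload_kb <= threshold:
        return "aiter_ca"        # small payload: aiter custom allreduce
    return "vllm_qr"             # large payload: quick allreduce

# A 4096 x 8192 fp16 tensor is 64 MiB, far above the 8192 KB threshold
# at tp_size=8, so it dispatches to quick allreduce:
x = torch.empty(4096, 8192, dtype=torch.float16)
assert pick_impl(x, 8) == "vllm_qr"
# A 1024-element fp16 tensor is only 2 KB, so it stays on aiter:
assert pick_impl(torch.empty(1024, dtype=torch.float16), 8) == "aiter_ca"
```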
@@ -55,8 +208,8 @@ def __init__(self,
             CustomAllreduce)
         from vllm.distributed.device_communicators.pynccl import (
             PyNcclCommunicator)
-        from vllm.distributed.device_communicators.quick_all_reduce import (
-            QuickAllReduce)
+        from vllm.distributed.device_communicators.symm_mem import (
+            SymmMemCommunicator)
 
         self.pynccl_comm: Optional[PyNcclCommunicator] = None
         if use_pynccl and self.world_size > 1:
@@ -66,7 +219,11 @@ def __init__(self,
             )
 
         self.ca_comm: Optional[CustomAllreduce] = None
-        self.qr_comm: Optional[QuickAllReduce] = None
+        self.symm_mem_comm: Optional[SymmMemCommunicator] = None
+
+        # Allreduce dispatcher for the ROCm platform.
+        self.rocm_allreduce_dispatcher: Optional[
+            ROCmAllreduceDispatcher] = None
+
         if use_custom_allreduce and self.world_size > 1:
             # Initialize a custom fast all-reduce implementation.
             self.ca_comm = CustomAllreduce(
@@ -75,13 +232,19 @@ def __init__(self,
             )
 
         if current_platform.is_rocm():
-            # Initialize a custom quick all-reduce implementation for AMD.
-            # Quick reduce is designed as a complement to custom allreduce.
-            # Based on quickreduce (https://github.com/mk1-project/quickreduce).
-            # If it's a rocm, 'use_custom_allreduce==True' means it must
-            # currently be an MI300 series.
-            self.qr_comm = QuickAllReduce(group=self.cpu_group,
-                                          device=self.device)
+            logger.info("Initializing ROCm allreduce dispatcher.")
+            self.rocm_allreduce_dispatcher = \
+                ROCmAllreduceDispatcher(group=self.cpu_group,
+                                        device=self.device,
+                                        ca_comm=self.ca_comm,
+                                        pynccl_comm=self.pynccl_comm)
+
+        if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
+            self.symm_mem_comm = SymmMemCommunicator(
+                group=self.cpu_group,
+                device=self.device,
+            )
 
         if self.use_all2all:
             all2all_backend = envs.VLLM_ALL2ALL_BACKEND
             if all2all_backend == "naive":
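The dispatcher constructed here populates a `name -> (impl, check)` registry in its `__init__`, where a `None` check (the pynccl entry) means "always eligible". A self-contained sketch of that lookup pattern, using hypothetical placeholder callables rather than the real communicator methods:

```python
from typing import Callable, Optional

Impl = Callable[[object], object]
Check = Optional[Callable[[object], bool]]

impls: dict[str, tuple[Impl, Check]] = {}

def fake_reduce(t: object) -> object:
    return t                                      # placeholder "allreduce"

impls["pynccl"] = (fake_reduce, None)             # no eligibility check
impls["vllm_ca"] = (fake_reduce, lambda t: True)  # toy check, always passes

def eligible(name: str, t: object) -> bool:
    if name not in impls:
        return False
    check = impls[name][1]
    return check is None or check(t)

assert eligible("pynccl", object())        # None check: always eligible
assert eligible("vllm_ca", object())       # passes its check
assert not eligible("aiter_ca", object())  # never registered
```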
@@ -104,23 +267,29 @@ def __init__(self,
             raise ValueError(f"Unknown all2all backend: {all2all_backend}")
 
     def all_reduce(self, input_):
-        # always try quick reduce first, then custom allreduce,
-        # and then pynccl. (quick reduce just for ROCM MI3*)
-        qr_comm = self.qr_comm
-        if qr_comm is not None and not qr_comm.disabled and \
-            qr_comm.should_quick_allreduce(input_):
-            out = qr_comm.quick_all_reduce(input_)
-            assert out is not None
-            return out
-        ca_comm = self.ca_comm
-        if ca_comm is not None and not ca_comm.disabled and \
-            ca_comm.should_custom_ar(input_):
-            out = ca_comm.custom_all_reduce(input_)
-            assert out is not None
-            return out
-        pynccl_comm = self.pynccl_comm
-        assert pynccl_comm is not None
-        out = pynccl_comm.all_reduce(input_)
+        if current_platform.is_rocm() and \
+                self.rocm_allreduce_dispatcher is not None:
+            op = self.rocm_allreduce_dispatcher.dispatch(input_)
+            logger.debug("ROCm allreduce dispatcher dispatched: %s", op)
+            out = op(input_)
+        else:
+            # On CUDA: custom allreduce first, then symmetric memory,
+            # then pynccl.
+            ca_comm = self.ca_comm
+            if ca_comm is not None and not ca_comm.disabled and \
+                    ca_comm.should_custom_ar(input_):
+                out = ca_comm.custom_all_reduce(input_)
+                assert out is not None
+                return out
+            symm_mem_comm = self.symm_mem_comm
+            if symm_mem_comm is not None and \
+                    symm_mem_comm.should_use_symm_mem(input_):
+                out = symm_mem_comm.all_reduce(input_)
+                assert out is not None
+                return out
+
+            pynccl_comm = self.pynccl_comm
+            assert pynccl_comm is not None
+            out = pynccl_comm.all_reduce(input_)
+
         if out is None:
             # fall back to the default all-reduce using PyTorch.
             # this usually happens during testing.
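Net effect on the non-ROCm path: the old quick-reduce step is gone, and the selection order is now custom allreduce, then the env-gated symmetric-memory allreduce, then pynccl, with `torch.distributed.all_reduce` still the last resort when pynccl returns `None`. A condensed, runnable restatement of that ordering (booleans stand in for the real `disabled`/`should_*` checks):

```python
def select_cuda_allreduce(ca_ok: bool, symm_ok: bool) -> str:
    """Mirror the priority order of the new CUDA branch."""
    if ca_ok:
        return "custom_allreduce"  # CustomAllreduce.should_custom_ar passed
    if symm_ok:
        return "symm_mem"          # SymmMemCommunicator.should_use_symm_mem
    return "pynccl"                # may still fall back to torch.distributed

assert select_cuda_allreduce(True, True) == "custom_allreduce"
assert select_cuda_allreduce(False, True) == "symm_mem"
assert select_cuda_allreduce(False, False) == "pynccl"
```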