diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py
index 0a29229a7d77..f62a80e34fb9 100644
--- a/vllm/distributed/device_communicators/cuda_communicator.py
+++ b/vllm/distributed/device_communicators/cuda_communicator.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from functools import cache
-from typing import Optional, Union
+from typing import Callable, Optional, Union
 
 import torch
 from torch.distributed import ProcessGroup
@@ -23,6 +23,140 @@ def is_rocm_aiter_custom_allreduce_enabled() -> bool:
         and envs.VLLM_ROCM_USE_AITER \
         and envs.VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE
 
+
+class ROCmAllreduceDispatcher:
+    """Dispatches each allreduce to the most performant implementation
+    available for the payload size of the input tensor.
+    Supported on the AMD ROCm platform only."""
+
+    def __init__(self,
+                 group: ProcessGroup,
+                 device: Union[int, str, torch.device],
+                 ca_comm=None,
+                 pynccl_comm=None):
+        self.process_group = group
+        self.device = device
+        self.cur_device_arch = self._get_current_device_arch()
+        self.supported_device_archs = ["gfx94", "gfx95"]
+
+        self.tp_size = torch.distributed.get_world_size(
+            group=self.process_group)
+
+        # Dispatch thresholds (unit: KB) keyed by tp_size. Payloads at
+        # or below the threshold prefer custom allreduce; larger
+        # payloads prefer quick allreduce.
+        self.gfx95_thresholds = {
+            2: 512,
+            4: 2048,
+            8: 32768,
+        }
+        self.gfx94_thresholds = {
+            2: 2048,
+            4: 4096,
+            8: 8192,
+        }
+
+        # impl name -> (allreduce impl, optional applicability check)
+        self.available_allreduce_impls: dict[
+            str, tuple[Callable, Optional[Callable]]] = {}
+
+        assert pynccl_comm is not None, (
+            "ROCmAllreduceDispatcher requires a pynccl communicator "
+            "as its fallback implementation.")
+        self.available_allreduce_impls["pynccl"] = (pynccl_comm.all_reduce,
+                                                    None)
+        self.fallback_impl = pynccl_comm.all_reduce
+
+        if ca_comm is not None and not ca_comm.disabled:
+            self.available_allreduce_impls["vllm_ca"] = \
+                (ca_comm.custom_all_reduce, ca_comm.should_custom_ar)
+
+        # Initialize a custom quick all-reduce implementation for AMD.
+        # Quick reduce is designed as a complement to custom allreduce.
+        # Based on quickreduce (https://github.com/mk1-project/quickreduce).
+        # On ROCm, 'use_custom_allreduce==True' currently implies an
+        # MI300-series GPU.
+        from vllm.distributed.device_communicators.quick_all_reduce import (
+            QuickAllReduce)
+        self.qr_comm = QuickAllReduce(group=self.process_group,
+                                      device=self.device)
+        if not self.qr_comm.disabled:
+            self.available_allreduce_impls["vllm_qr"] = \
+                (self.qr_comm.quick_all_reduce,
+                 self.qr_comm.should_quick_allreduce)
+
+    def _get_current_device_arch(self) -> str:
+        """Get the GFX microarchitecture name of the current device."""
+        # TODO(zejun): Add more device architectures
+        device_arch = torch.cuda.get_device_properties(
+            self.device).gcnArchName
+        if "gfx95" in device_arch:
+            return "gfx95"
+        elif "gfx94" in device_arch:
+            return "gfx94"
+        else:
+            return device_arch
+
+    def _should_allreduce(self, input_: torch.Tensor, impl_name: str) -> bool:
+        """Return True if the named impl is registered and accepts input_."""
+        if impl_name not in self.available_allreduce_impls:
+            return False
+        check_impl = self.available_allreduce_impls[impl_name][1]
+        return check_impl is None or check_impl(input_)
+
+    def _dispatch_with_thresholds(self,
+                                  input_: torch.Tensor,
+                                  payload_size_kb: int,
+                                  thresholds: dict[int, int]) -> Callable:
+        """Pick an implementation from a tp_size -> threshold (KB) table.
+
+        Payloads at or below the threshold prefer custom allreduce;
+        larger payloads prefer quick allreduce. In either case, fall
+        back to pynccl when the preferred impl rejects the tensor.
+        """
+        if self.tp_size not in thresholds:
+            return self.fallback_impl
+
+        if payload_size_kb <= thresholds[self.tp_size]:
+            if self._should_allreduce(input_, "vllm_ca"):
+                return self.available_allreduce_impls["vllm_ca"][0]
+            return self.fallback_impl
+        if self._should_allreduce(input_, "vllm_qr"):
+            return self.available_allreduce_impls["vllm_qr"][0]
+        return self.fallback_impl
+
+    def _dispatch_impl(self,
+                       input_: torch.Tensor,
+                       payload_size_kb: int,
+                       device_arch: str) -> Callable:
+        if device_arch not in self.supported_device_archs:
+            logger.debug(
+                "Device architecture %s is not supported by the ROCm "
+                "allreduce dispatcher, falling back to pynccl", device_arch)
+            return self.fallback_impl
+
+        thresholds = (self.gfx95_thresholds if device_arch == "gfx95"
+                      else self.gfx94_thresholds)
+        return self._dispatch_with_thresholds(input_, payload_size_kb,
+                                              thresholds)
+
+    def dispatch(self, input_: torch.Tensor) -> Callable:
+        """Dispatch to an allreduce implementation for this tensor."""
+        # Payload size in KB.
+        payload_size_kb = (input_.numel() * input_.element_size()) // 1024
+        return self._dispatch_impl(input_, payload_size_kb,
+                                   self.cur_device_arch)
+
 
 class CudaCommunicator(DeviceCommunicatorBase):
 
@@ -55,8 +189,6 @@ def __init__(self,
             CustomAllreduce)
         from vllm.distributed.device_communicators.pynccl import (
             PyNcclCommunicator)
-        from vllm.distributed.device_communicators.quick_all_reduce import (
-            QuickAllReduce)
         from vllm.distributed.device_communicators.symm_mem import (
             SymmMemCommunicator)
 
@@ -68,8 +200,11 @@ def __init__(self,
         )
 
         self.ca_comm: Optional[CustomAllreduce] = None
-        self.qr_comm: Optional[QuickAllReduce] = None
        self.symm_mem_comm: Optional[SymmMemCommunicator] = None
+
+        # All-reduce dispatcher for the ROCm platform.
+        self.rocm_allreduce_dispatcher: Optional[ROCmAllreduceDispatcher] = None
+
         if use_custom_allreduce and self.world_size > 1:
             # Initialize a custom fast all-reduce implementation.
             self.ca_comm = CustomAllreduce(
@@ -78,13 +213,13 @@ def __init__(self,
             )
 
         if current_platform.is_rocm():
-            # Initialize a custom quick all-reduce implementation for AMD.
-            # Quick reduce is designed as a complement to custom allreduce.
-            # Based on quickreduce (https://github.com/mk1-project/quickreduce).
-            # If it's a rocm, 'use_custom_allreduce==True' means it must
-            # currently be an MI300 series.
-            self.qr_comm = QuickAllReduce(group=self.cpu_group,
-                                          device=self.device)
+            self.rocm_allreduce_dispatcher = \
+                ROCmAllreduceDispatcher(group=self.cpu_group,
+                                        device=self.device,
+                                        ca_comm=self.ca_comm,
+                                        pynccl_comm=self.pynccl_comm)
+            logger.info("Initialized ROCm allreduce dispatcher.")
+
         if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
             self.symm_mem_comm = SymmMemCommunicator(
                 group=self.cpu_group,
@@ -113,29 +248,29 @@ def __init__(self,
             raise ValueError(f"Unknown all2all backend: {all2all_backend}")
 
     def all_reduce(self, input_):
-        # always try quick reduce first, then custom allreduce,
-        # and then pynccl. (quick reduce just for ROCM MI3*)
-        qr_comm = self.qr_comm
-        if qr_comm is not None and not qr_comm.disabled and \
-            qr_comm.should_quick_allreduce(input_):
-            out = qr_comm.quick_all_reduce(input_)
-            assert out is not None
-            return out
-        ca_comm = self.ca_comm
-        if ca_comm is not None and not ca_comm.disabled and \
-            ca_comm.should_custom_ar(input_):
-            out = ca_comm.custom_all_reduce(input_)
-            assert out is not None
-            return out
-        symm_mem_comm = self.symm_mem_comm
-        if symm_mem_comm is not None and \
-            symm_mem_comm.should_use_symm_mem(input_):
-            out = symm_mem_comm.all_reduce(input_)
-            assert out is not None
-            return out
-        pynccl_comm = self.pynccl_comm
-        assert pynccl_comm is not None
-        out = pynccl_comm.all_reduce(input_)
+        if (current_platform.is_rocm()
+                and self.rocm_allreduce_dispatcher is not None):
+            op = self.rocm_allreduce_dispatcher.dispatch(input_)
+            logger.debug("ROCm allreduce dispatcher selected %s", op)
+            out = op(input_)
+        else:
+            ca_comm = self.ca_comm
+            if ca_comm is not None and not ca_comm.disabled and \
+                ca_comm.should_custom_ar(input_):
+                out = ca_comm.custom_all_reduce(input_)
+                assert out is not None
+                return out
+            symm_mem_comm = self.symm_mem_comm
+            if symm_mem_comm is not None and \
+                symm_mem_comm.should_use_symm_mem(input_):
+                out = symm_mem_comm.all_reduce(input_)
+                assert out is not None
+                return out
+
+            pynccl_comm = self.pynccl_comm
+            assert pynccl_comm is not None
+            out = pynccl_comm.all_reduce(input_)
+
         if out is None:
             # fall back to the default all-reduce using PyTorch.
             # this usually happens during testing.
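
Note for reviewers: the dispatch decision in this patch reduces to a payload-size lookup in a tp_size -> threshold table. The standalone sketch below mirrors that logic so it can be sanity-checked without ROCm hardware or a vLLM build. `pick_backend` and `payload_size_kb` are illustrative stand-ins for `_dispatch_with_thresholds`, not part of the patch, and the per-impl applicability checks (`should_custom_ar`, `should_quick_allreduce`) are omitted for brevity.

    import torch

    # tp_size -> threshold in KB, copied from gfx94_thresholds above.
    GFX94_THRESHOLDS = {2: 2048, 4: 4096, 8: 8192}


    def payload_size_kb(t: torch.Tensor) -> int:
        # Same formula the dispatcher uses: total bytes // 1024.
        return (t.numel() * t.element_size()) // 1024


    def pick_backend(t: torch.Tensor, tp_size: int) -> str:
        # Unknown tp_size: fall back to pynccl, as the dispatcher does.
        if tp_size not in GFX94_THRESHOLDS:
            return "pynccl"
        # Small payloads prefer custom allreduce; large ones quick allreduce.
        if payload_size_kb(t) <= GFX94_THRESHOLDS[tp_size]:
            return "vllm_ca"
        return "vllm_qr"


    # A 4096 x 4096 fp16 tensor is 32768 KB, above the 8192 KB threshold
    # for tp_size=8 on gfx94, so quick allreduce would be preferred.
    assert pick_backend(torch.empty(4096, 4096, dtype=torch.float16),
                        8) == "vllm_qr"
    # A 1024 x 1024 fp16 tensor is 2048 KB, at or below the threshold,
    # so custom allreduce would be preferred.
    assert pick_backend(torch.empty(1024, 1024, dtype=torch.float16),
                        8) == "vllm_ca"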