__all__ = [
    "initialize_nccl_comm",
    "nccl_split",
    "nccl_allgather",
    "nccl_allreduce",
    "nccl_bcast",
    "nccl_asarray"
]

from enum import IntEnum
from mpi4py import MPI
import os
import numpy as np
import cupy as cp
import cupy.cuda.nccl as nccl

cupy_to_nccl_dtype = {
    "float32": nccl.NCCL_FLOAT32,
    "float64": nccl.NCCL_FLOAT64,
    "int32": nccl.NCCL_INT32,
    "int64": nccl.NCCL_INT64,
    "uint8": nccl.NCCL_UINT8,
    "int8": nccl.NCCL_INT8,
    "uint32": nccl.NCCL_UINT32,
    "uint64": nccl.NCCL_UINT64,
}


class NcclOp(IntEnum):
    SUM = nccl.NCCL_SUM
    PROD = nccl.NCCL_PROD
    MAX = nccl.NCCL_MAX
    MIN = nccl.NCCL_MIN


def mpi_op_to_nccl(mpi_op) -> NcclOp:
    """ Map an MPI reduction operation to its NCCL equivalent

    Parameters
    ----------
    mpi_op : :obj:`MPI.Op`
        An MPI reduction operation (e.g., MPI.SUM, MPI.PROD, MPI.MAX, MPI.MIN).

    Returns
    -------
    NcclOp : :obj:`IntEnum`
        The corresponding NCCL reduction operation.
    """
    if mpi_op is MPI.SUM:
        return NcclOp.SUM
    elif mpi_op is MPI.PROD:
        return NcclOp.PROD
    elif mpi_op is MPI.MAX:
        return NcclOp.MAX
    elif mpi_op is MPI.MIN:
        return NcclOp.MIN
    else:
        raise ValueError(f"Unsupported MPI.Op for NCCL: {mpi_op}")


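# Illustrative note (not part of the module API): this helper is called
# internally by nccl_allreduce to translate the user-facing MPI.Op, e.g.
#
#     mpi_op_to_nccl(MPI.MAX)   # -> NcclOp.MAX
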
def initialize_nccl_comm() -> nccl.NcclCommunicator:
    """ Initialize the NCCL world communicator across all GPU devices

    Each GPU must be managed by exactly one MPI process, i.e. the number of
    MPI processes launched must equal the number of GPUs taking part in the
    communication.

    Returns
    -------
    nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
        The NCCL world communicator
    """
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()
    device_id = int(
        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")
        or rank % cp.cuda.runtime.getDeviceCount()
    )
    cp.cuda.Device(device_id).use()

    if rank == 0:
        with cp.cuda.Device(device_id):
            nccl_id_bytes = nccl.get_unique_id()
    else:
        nccl_id_bytes = None
    nccl_id_bytes = comm.bcast(nccl_id_bytes, root=0)

    nccl_comm = nccl.NcclCommunicator(size, nccl_id_bytes, rank)
    return nccl_comm


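# Usage sketch for initialize_nccl_comm (illustrative only; assumes one MPI
# process per GPU, e.g. launched as `mpiexec -n <ngpus> python script.py`):
#
#     nccl_comm = initialize_nccl_comm()
#     # every rank now holds an NCCL handle bound to its own GPU and can
#     # take part in the collectives defined below
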
def nccl_split(mask) -> nccl.NcclCommunicator:
    """ NCCL equivalent of MPI.Split()

    Split the world communicator into multiple NCCL subcommunicators,
    one per group of ranks sharing the same mask value.

    Parameters
    ----------
    mask : :obj:`list`
        Mask defining subsets of ranks to consider when performing 'global'
        operations on the distributed array such as dot product or norm.

    Returns
    -------
    sub_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
        Subcommunicator according to mask
    """
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    sub_comm = comm.Split(color=mask[rank], key=rank)

    sub_rank = sub_comm.Get_rank()
    sub_size = sub_comm.Get_size()

    if sub_rank == 0:
        nccl_id_bytes = nccl.get_unique_id()
    else:
        nccl_id_bytes = None
    nccl_id_bytes = sub_comm.bcast(nccl_id_bytes, root=0)
    sub_comm = nccl.NcclCommunicator(sub_size, nccl_id_bytes, sub_rank)

    return sub_comm


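# Usage sketch for nccl_split (illustrative only; assumes 4 MPI ranks): with
# mask = [0, 0, 1, 1], ranks 0-1 end up in one NCCL subcommunicator and
# ranks 2-3 in another, mirroring MPI.COMM_WORLD.Split(color=mask[rank]).
#
#     mask = [0, 0, 1, 1]
#     sub_comm = nccl_split(mask)
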
def nccl_allgather(nccl_comm, send_buf, recv_buf=None) -> cp.ndarray:
    """ NCCL equivalent of MPI_Allgather. Gathers data from all GPUs
    and distributes the concatenated result to all participants.

    Parameters
    ----------
    nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
        The NCCL communicator over which data will be gathered.
    send_buf : :obj:`cupy.ndarray` or array-like
        The data buffer from the local GPU to be sent.
    recv_buf : :obj:`cupy.ndarray`, optional
        The buffer to receive data from all GPUs. If None, a new
        buffer will be allocated with the appropriate shape.

    Returns
    -------
    recv_buf : :obj:`cupy.ndarray`
        A buffer containing the gathered data from all GPUs.
    """
    send_buf = (
        send_buf if isinstance(send_buf, cp.ndarray) else cp.asarray(send_buf)
    )
    if recv_buf is None:
        recv_buf = cp.zeros(
            MPI.COMM_WORLD.Get_size() * send_buf.size,
            dtype=send_buf.dtype,
        )
    nccl_comm.allGather(
        send_buf.data.ptr,
        recv_buf.data.ptr,
        send_buf.size,
        cupy_to_nccl_dtype[str(send_buf.dtype)],
        cp.cuda.Stream.null.ptr,
    )
    return recv_buf


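# Usage sketch for nccl_allgather (illustrative only; assumes the world
# communicator from initialize_nccl_comm and equally sized local buffers):
#
#     nccl_comm = initialize_nccl_comm()
#     rank = MPI.COMM_WORLD.Get_rank()
#     local = cp.full(3, rank, dtype=cp.float32)       # e.g. [1., 1., 1.] on rank 1
#     gathered = nccl_allgather(nccl_comm, local)      # flat array of size 3 * nranks
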
def nccl_allreduce(nccl_comm, send_buf, recv_buf=None, op: MPI.Op = MPI.SUM) -> cp.ndarray:
    """ NCCL equivalent of MPI_Allreduce. Applies a reduction operation
    (e.g., sum, max) across all GPUs and distributes the result.

    Parameters
    ----------
    nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
        The NCCL communicator used for collective communication.
    send_buf : :obj:`cupy.ndarray` or array-like
        The data buffer from the local GPU to be reduced.
    recv_buf : :obj:`cupy.ndarray`, optional
        The buffer to store the result of the reduction. If None,
        a new buffer will be allocated with the appropriate shape.
    op : :obj:`mpi4py.MPI.Op`, optional
        The reduction operation to apply. Defaults to MPI.SUM.

    Returns
    -------
    recv_buf : :obj:`cupy.ndarray`
        A buffer containing the result of the reduction, broadcasted
        to all GPUs.
    """
    send_buf = (
        send_buf if isinstance(send_buf, cp.ndarray) else cp.asarray(send_buf)
    )
    if recv_buf is None:
        recv_buf = cp.zeros(send_buf.size, dtype=send_buf.dtype)

    nccl_comm.allReduce(
        send_buf.data.ptr,
        recv_buf.data.ptr,
        send_buf.size,
        cupy_to_nccl_dtype[str(send_buf.dtype)],
        mpi_op_to_nccl(op),
        cp.cuda.Stream.null.ptr,
    )
    return recv_buf


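# Usage sketch for nccl_allreduce (illustrative only; assumes the world
# communicator from initialize_nccl_comm):
#
#     nccl_comm = initialize_nccl_comm()
#     local = cp.arange(4, dtype=cp.float32)
#     total = nccl_allreduce(nccl_comm, local)              # elementwise sum
#     maxed = nccl_allreduce(nccl_comm, local, op=MPI.MAX)  # elementwise max
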
def nccl_bcast(nccl_comm, local_array, index, value) -> None:
    """ NCCL equivalent of MPI_Bcast. Broadcasts a single value at the given index
    from the root GPU (rank 0) to all other GPUs.

    Parameters
    ----------
    nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
        The NCCL communicator used for collective communication.
    local_array : :obj:`cupy.ndarray`
        The local array on each GPU. The value at `index` will be broadcasted.
    index : :obj:`int`
        The index in the array to be broadcasted.
    value : :obj:`scalar`
        The value to broadcast (only used by the root GPU, rank 0).

    Returns
    -------
    None
    """
    if nccl_comm.rank_id() == 0:
        local_array[index] = value
    nccl_comm.bcast(
        local_array[index].data.ptr,
        local_array[index].size,
        cupy_to_nccl_dtype[str(local_array[index].dtype)],
        0,
        cp.cuda.Stream.null.ptr,
    )


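# Usage sketch for nccl_bcast (illustrative only; assumes the world
# communicator from initialize_nccl_comm): rank 0 writes `value` into
# position `index` of its local array and the same entry is then
# overwritten in place on every other GPU.
#
#     nccl_comm = initialize_nccl_comm()
#     local = cp.zeros(5, dtype=cp.float64)
#     nccl_bcast(nccl_comm, local, index=2, value=3.14)   # local[2] == 3.14 everywhere
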
def nccl_asarray(nccl_comm, local_array, local_shapes, axis) -> cp.ndarray:
    """Global view of the array

    Gather all local GPU arrays into a single global array via NCCL all-gather.

    Parameters
    ----------
    nccl_comm : :obj:`cupy.cuda.nccl.NcclCommunicator`
        The NCCL communicator used for collective communication.
    local_array : :obj:`cupy.ndarray`
        The local array on the current GPU.
    local_shapes : :obj:`list`
        A list of shapes for each GPU local array (used to trim padding).
    axis : :obj:`int`
        The axis along which to concatenate the gathered arrays.

    Returns
    -------
    final_array : :obj:`cupy.ndarray`
        Global array gathered from all GPUs and concatenated along `axis`.

    Notes
    -----
    NCCL's allGather requires the sending buffer to have the same size on every device.
    Therefore, padding is required when the array is not evenly partitioned across
    all ranks. Padding is applied so that each dimension of the sending buffer
    matches the maximum size of that dimension across all ranks.
    """
    sizes_each_dim = list(zip(*local_shapes))

    send_shape = tuple(map(max, sizes_each_dim))
    pad_size = [
        (0, s_shape - l_shape) for s_shape, l_shape in zip(send_shape, local_array.shape)
    ]

    send_buf = cp.pad(
        local_array, pad_size, mode="constant", constant_values=0
    )

    # NCCL recommends one MPI process per GPU, so the size of the receiving
    # buffer can be inferred from the number of local shapes
    ndev = len(local_shapes)
    recv_buf = cp.zeros(ndev * send_buf.size, dtype=send_buf.dtype)
    nccl_allgather(nccl_comm, send_buf, recv_buf)

    # extract the individual (padded) array of each device
    chunk_size = np.prod(send_shape)
    chunks = [
        recv_buf[i * chunk_size:(i + 1) * chunk_size] for i in range(ndev)
    ]

    # Remove padding from each chunk: padded values may appear in the middle of
    # the flat buffer, so each chunk is reshaped and sliced dimension by dimension
    for i in range(ndev):
        slicing = tuple(slice(0, end) for end in local_shapes[i])
        chunks[i] = chunks[i].reshape(send_shape)[slicing]
    # combine back into a single global array
    return cp.concatenate(chunks, axis=axis)
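

# Usage sketch for nccl_asarray (illustrative only; assumes two MPI ranks with
# rows unevenly partitioned along axis 0). Collecting local_shapes with an
# mpi4py allgather is an assumption about how callers obtain the shapes.
#
#     nccl_comm = initialize_nccl_comm()
#     rank = MPI.COMM_WORLD.Get_rank()
#     nrows = 3 if rank == 0 else 2                        # uneven split along axis 0
#     local = cp.ones((nrows, 4), dtype=cp.float32) * rank
#     local_shapes = MPI.COMM_WORLD.allgather(local.shape)
#     global_arr = nccl_asarray(nccl_comm, local, local_shapes, axis=0)
#     # global_arr has shape (5, 4): rank 0's rows followed by rank 1's rows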