|
| 1 | +/* |
| 2 | + * Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All rights reserved. |
| 3 | + */ |
| 4 | + |
| 5 | +#include "config.h" |
| 6 | + |
| 7 | +#include "test-common.h" |
| 8 | + |
| 9 | +#include <deque> |
| 10 | +#include <vector> |
| 11 | + |
| 12 | +static inline ncclResult_t |
| 13 | +poll_request_completion(ncclGin_v11_t *extGin, std::deque<void *> &request_deque, void *collComm) |
| 14 | +{ |
| 15 | + /* Wait for outstanding requests */ |
| 16 | + int done = 0; |
| 17 | + OFINCCLCHECK(extGin->test(collComm, request_deque.front(), &done)); |
| 18 | + if (done) { |
| 19 | + request_deque.pop_front(); |
| 20 | + } else { |
| 21 | + /** |
| 22 | + * Note: The NCCL GIN proxy thread will call ginProgress repeatedly |
| 23 | + * until the communicator is closed. We emulate that throughout this |
| 24 | + * test by ensuring each rank calls `ginProgress` until the barrier |
| 25 | + * before verification is reached. |
| 26 | + */ |
| 27 | + OFINCCLCHECK(extGin->ginProgress(collComm)); |
| 28 | + } |
| 29 | + return ncclSuccess; |
| 30 | +} |
| 31 | + |
| 32 | +static inline ncclResult_t alloc_and_reg_buff(ncclGin_v11_t *extGin, void *collComm, size_t size, |
| 33 | + int buffer_type, int value, void **buff, |
| 34 | + void **mr_handle) |
| 35 | +{ |
| 36 | + constexpr uint64_t mrFlags = 0; /* TODO FORCE_SO */ |
| 37 | + OFINCCLCHECK(allocate_buff(buff, size, buffer_type)); |
| 38 | + OFINCCLCHECK(initialize_buff(*buff, size, buffer_type, value)); |
| 39 | + |
| 40 | + void *gin_handle = nullptr; |
| 41 | + OFINCCLCHECK(extGin->regMrSym(collComm, *buff, size, buffer_type, mrFlags, mr_handle, |
| 42 | + &gin_handle)); |
| 43 | + assert(*mr_handle != nullptr && gin_handle != nullptr); |
| 44 | + |
| 45 | + return ncclSuccess; |
| 46 | +} |
| 47 | + |
| 48 | +static inline ncclResult_t verify_buff(int rank, void *buff, int send_val) |
| 49 | +{ |
| 50 | + uint8_t verif_buf[SEND_SIZE]; |
| 51 | + CUDACHECK(cudaMemcpy(verif_buf, buff, SEND_SIZE, cudaMemcpyDefault)); |
| 52 | + for (int i = 0; i < SEND_SIZE; ++i) { |
| 53 | + if (verif_buf[i] != send_val) { |
| 54 | + NCCL_OFI_WARN("Test failed: verif_buf did not have expected value"); |
| 55 | + NCCL_OFI_WARN("Rank %d, Index %d, expected %hu but got %hu", rank, i, |
| 56 | + send_val, verif_buf[i]); |
| 57 | + return ncclSystemError; |
| 58 | + } |
| 59 | + } |
| 60 | + |
| 61 | + return ncclSuccess; |
| 62 | +} |
| 63 | + |
| 64 | +int main(int argc, char *argv[]) |
| 65 | +{ |
| 66 | + ncclResult_t res = ncclSuccess; |
| 67 | + int rank, nranks, proc_name_len, local_rank = 0; |
| 68 | + int buffer_type = NCCL_PTR_HOST; |
| 69 | + |
| 70 | + /* Plugin defines */ |
| 71 | + int ndev; |
| 72 | + |
| 73 | + int dev; |
| 74 | + |
| 75 | + /* Start up MPI */ |
| 76 | + MPI_Init(&argc, &argv); |
| 77 | + MPI_Comm_rank(MPI_COMM_WORLD, &rank); |
| 78 | + MPI_Comm_size(MPI_COMM_WORLD, &nranks); |
| 79 | + |
| 80 | + std::vector<char[NCCL_NET_HANDLE_MAXSIZE]> handles(nranks); |
| 81 | + std::vector<void *> handles_ptrs(nranks); |
| 82 | + |
| 83 | + ofi_log_function = logger; |
| 84 | + |
| 85 | + if (nranks < 2) { |
| 86 | + NCCL_OFI_WARN("Expected at least two ranks but got %d. " |
| 87 | + "The gin functional test should be run with at least two ranks.", |
| 88 | + nranks); |
| 89 | + res = ncclInvalidArgument; |
| 90 | + return res; |
| 91 | + } |
| 92 | + |
| 93 | + /* All processors IDs, used to find out the local rank */ |
| 94 | + std::vector<char> all_proc_name(nranks * MPI_MAX_PROCESSOR_NAME); |
| 95 | + |
| 96 | + MPI_Get_processor_name(&all_proc_name[PROC_NAME_IDX(rank)], &proc_name_len); |
| 97 | + MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, all_proc_name.data(), |
| 98 | + MPI_MAX_PROCESSOR_NAME, MPI_BYTE, MPI_COMM_WORLD); |
| 99 | + |
| 100 | + /* Determine local rank */ |
| 101 | + for (int i = 0; i < nranks; i++) { |
| 102 | + if (!strcmp(&all_proc_name[PROC_NAME_IDX(rank)], |
| 103 | + &all_proc_name[PROC_NAME_IDX(i)])) { |
| 104 | + if (i < rank) { |
| 105 | + ++local_rank; |
| 106 | + } |
| 107 | + } |
| 108 | + } |
| 109 | + |
| 110 | + /* Set CUDA device for subsequent device memory allocation, in case GDR is used */ |
| 111 | + NCCL_OFI_TRACE(NCCL_NET, "Using CUDA device %d for memory allocation", local_rank); |
| 112 | + CUDACHECK(cudaSetDevice(local_rank)); |
| 113 | + |
| 114 | + /* Get external Network from NCCL-OFI library */ |
| 115 | + set_system_page_size(); |
| 116 | + auto *net_plugin_handle = load_netPlugin(); |
| 117 | + auto *extNet = get_netPlugin_symbol(net_plugin_handle); |
| 118 | + auto *extGin = get_ginPlugin_symbol(net_plugin_handle); |
| 119 | + if (extNet == nullptr || extGin == NULL) { |
| 120 | + res = ncclInternalError; |
| 121 | + return res; |
| 122 | + } |
| 123 | + |
| 124 | + void *netCtx = nullptr; |
| 125 | + ncclNetCommConfig_v11_t netConfig = {}; |
| 126 | + /** |
| 127 | + * Although the net plugin isn't used in this test, the GIN plugin |
| 128 | + * requires the net plugin to be initialized, since they share some of |
| 129 | + * the underlying structures. NCCL will always initialize the net plugin |
| 130 | + * before the GIN plugin, so emulating that behavior here. |
| 131 | + */ |
| 132 | + OFINCCLCHECK(extNet->init(&netCtx, 0, &netConfig, &logger, nullptr)); |
| 133 | + |
| 134 | + void *ginCtx = nullptr; |
| 135 | + |
| 136 | + /* Init API */ |
| 137 | + OFINCCLCHECK(extGin->init(&ginCtx, 0, &logger)); |
| 138 | + NCCL_OFI_INFO(NCCL_NET, "Process rank %d started. NCCL-GIN device used on %s is %s.", rank, |
| 139 | + &all_proc_name[PROC_NAME_IDX(rank)], extGin->name); |
| 140 | + |
| 141 | + /* Devices API */ |
| 142 | + OFINCCLCHECK(extGin->devices(&ndev)); |
| 143 | + NCCL_OFI_INFO(NCCL_NET, "Received %d network devices", ndev); |
| 144 | + |
| 145 | + /* Indicates if NICs support GPUDirect */ |
| 146 | + std::vector<int> test_support_gdr(ndev); |
| 147 | + |
| 148 | + /* Get Properties for the device */ |
| 149 | + for (dev = 0; dev < ndev; dev++) { |
| 150 | + ncclNetProperties_v11_t props = {}; |
| 151 | + OFINCCLCHECK(extGin->getProperties(dev, &props)); |
| 152 | + |
| 153 | + /* Set CUDA support */ |
| 154 | + test_support_gdr[dev] = is_gdr_supported_nic(props.ptrSupport); |
| 155 | + } |
| 156 | + |
| 157 | + dev = local_rank % ndev; |
| 158 | + |
| 159 | + NCCL_OFI_TRACE(NCCL_INIT, "Rank %d uses %d device for communication", rank, dev); |
| 160 | + |
| 161 | + if (test_support_gdr[dev] == 1) { |
| 162 | + NCCL_OFI_INFO(NCCL_INIT | NCCL_NET, |
| 163 | + "Network supports communication using CUDA buffers. Dev: %d", dev); |
| 164 | + buffer_type = NCCL_PTR_CUDA; |
| 165 | + } else { |
| 166 | + /* We aren't currently interested in the non-GDR use case for GIN */ |
| 167 | + NCCL_OFI_WARN("Network does not support communication using CUDA buffers. Dev: %d", |
| 168 | + dev); |
| 169 | + return 1; |
| 170 | + } |
| 171 | + |
| 172 | + void *listenComm = nullptr; |
| 173 | + OFINCCLCHECK(extGin->listen(ginCtx, dev, handles[rank], &listenComm)); |
| 174 | + assert(listenComm); |
| 175 | + |
| 176 | + /* Gather handles from all ranks */ |
| 177 | + MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, handles.data(), NCCL_NET_HANDLE_MAXSIZE, |
| 178 | + MPI_CHAR, MPI_COMM_WORLD); |
| 179 | + |
| 180 | + /* Prepare handles void** array */ |
| 181 | + for (int i = 0; i < nranks; ++i) { |
| 182 | + handles_ptrs[i] = &(handles[i]); |
| 183 | + } |
| 184 | + |
| 185 | + void *collComm = nullptr; |
| 186 | + OFINCCLCHECK( |
| 187 | + extGin->connect(ginCtx, handles_ptrs.data(), nranks, rank, listenComm, &collComm)); |
| 188 | + assert(collComm != nullptr); |
| 189 | + |
| 190 | + /* Allocate, register, and initialize all buffers to zero. */ |
| 191 | + void *put_buff = nullptr; |
| 192 | + void *put_mhandle = nullptr; |
| 193 | + OFINCCLCHECK(alloc_and_reg_buff(extGin, collComm, SEND_SIZE, buffer_type, 0, &put_buff, |
| 194 | + &put_mhandle)); |
| 195 | + |
| 196 | + void *put_signal_buff = nullptr; |
| 197 | + void *put_signal_mhandle = nullptr; |
| 198 | + OFINCCLCHECK(alloc_and_reg_buff(extGin, collComm, SEND_SIZE, buffer_type, 0, |
| 199 | + &put_signal_buff, &put_signal_mhandle)); |
| 200 | + |
| 201 | + void *signal_buf = nullptr; |
| 202 | + void *signal_mhandle = nullptr; |
| 203 | + OFINCCLCHECK(alloc_and_reg_buff(extGin, collComm, sizeof(uint64_t), buffer_type, 0, |
| 204 | + &signal_buf, &signal_mhandle)); |
| 205 | + |
| 206 | + const int send_val = 42; /* arbitrary */ |
| 207 | + const int NUM_REQS_PER_PEER = 64; |
| 208 | + assert((NUM_REQS_PER_PEER * 2) <= NCCL_OFI_MAX_REQUESTS); |
| 209 | + |
| 210 | + if (rank == 0) { |
| 211 | + OFINCCLCHECK(initialize_buff(put_buff, SEND_SIZE, buffer_type, send_val)); |
| 212 | + OFINCCLCHECK(initialize_buff(put_signal_buff, SEND_SIZE, buffer_type, send_val)); |
| 213 | + |
| 214 | + std::deque<void *> request_deque; |
| 215 | + |
| 216 | + for (int dst_rank = 1; dst_rank < nranks; ++dst_rank) { |
| 217 | + /* iput API */ |
| 218 | + for (int i = 0; i < NUM_REQS_PER_PEER; ++i) { |
| 219 | + void *request = nullptr; |
| 220 | + OFINCCLCHECK(extGin->iput(collComm, 0, put_mhandle, SEND_SIZE, 0, |
| 221 | + put_mhandle, dst_rank, &request)); |
| 222 | + assert(request != nullptr); |
| 223 | + request_deque.push_back(request); |
| 224 | + } |
| 225 | + |
| 226 | + /* iputSignal API */ |
| 227 | + for (int i = 0; i < NUM_REQS_PER_PEER; ++i) { |
| 228 | + /* TODO: Expand the test to cover other signal types, such as |
| 229 | + * NCCL_NET_SIGNAL_OP_ADD */ |
| 230 | + void *request = nullptr; |
| 231 | + OFINCCLCHECK(extGin->iputSignal(collComm, 0, put_signal_mhandle, |
| 232 | + SEND_SIZE, 0, put_signal_mhandle, |
| 233 | + dst_rank, 0, signal_mhandle, 1, |
| 234 | + NCCL_NET_SIGNAL_OP_INC, &request)); |
| 235 | + assert(request != nullptr); |
| 236 | + request_deque.push_back(request); |
| 237 | + } |
| 238 | + } |
| 239 | + |
| 240 | + /* Wait for remaining requests */ |
| 241 | + while (!request_deque.empty()) { |
| 242 | + OFINCCLCHECK(poll_request_completion(extGin, request_deque, collComm)); |
| 243 | + } |
| 244 | + } else { |
| 245 | + /* Validate that the signal_buff reaches the designated signal value */ |
| 246 | + uint64_t signal_h = 0; |
| 247 | + while (signal_h != NUM_REQS_PER_PEER) { |
| 248 | + OFINCCLCHECK(extGin->ginProgress(collComm)); |
| 249 | + CUDACHECK(cudaMemcpy(&signal_h, signal_buf, sizeof(uint64_t), |
| 250 | + cudaMemcpyDefault)); |
| 251 | + } |
| 252 | + } |
| 253 | + |
| 254 | + MPI_Request barrier_req; |
| 255 | + MPI_Ibarrier(MPI_COMM_WORLD, &barrier_req); |
| 256 | + int barrier_done = 0; |
| 257 | + while (!barrier_done) { |
| 258 | + /* Make progress on comm until all ranks reach the barrier */ |
| 259 | + OFINCCLCHECK(extGin->ginProgress(collComm)); |
| 260 | + MPI_Test(&barrier_req, &barrier_done, MPI_STATUS_IGNORE); |
| 261 | + } |
| 262 | + |
| 263 | + /* Verification */ |
| 264 | + NCCL_OFI_INFO(NCCL_NET, "Verifying result.."); |
| 265 | + OFINCCLCHECK(verify_buff(rank, put_buff, send_val)); |
| 266 | + OFINCCLCHECK(verify_buff(rank, put_signal_buff, send_val)); |
| 267 | + |
| 268 | + /* Cleanup APIs */ |
| 269 | + OFINCCLCHECK(extGin->deregMrSym(collComm, signal_mhandle)); |
| 270 | + signal_mhandle = nullptr; |
| 271 | + OFINCCLCHECK(extGin->deregMrSym(collComm, put_signal_mhandle)); |
| 272 | + put_signal_mhandle = nullptr; |
| 273 | + OFINCCLCHECK(extGin->deregMrSym(collComm, put_mhandle)); |
| 274 | + put_mhandle = nullptr; |
| 275 | + |
| 276 | + OFINCCLCHECK(extGin->closeColl(collComm)); |
| 277 | + collComm = nullptr; |
| 278 | + OFINCCLCHECK(extGin->closeListen(listenComm)); |
| 279 | + listenComm = nullptr; |
| 280 | + |
| 281 | + OFINCCLCHECK(extGin->finalize(ginCtx)); |
| 282 | + OFINCCLCHECK(extNet->finalize(netCtx)); |
| 283 | + |
| 284 | + dlclose(net_plugin_handle); |
| 285 | + |
| 286 | + MPI_Barrier(MPI_COMM_WORLD); |
| 287 | + MPI_Finalize(); |
| 288 | + |
| 289 | + /* Clean up local resources */ |
| 290 | + OFINCCLCHECK(deallocate_buffer(signal_buf, buffer_type)); |
| 291 | + signal_buf = nullptr; |
| 292 | + OFINCCLCHECK(deallocate_buffer(put_signal_buff, buffer_type)); |
| 293 | + put_signal_buff = nullptr; |
| 294 | + OFINCCLCHECK(deallocate_buffer(put_buff, buffer_type)); |
| 295 | + put_buff = nullptr; |
| 296 | + |
| 297 | + NCCL_OFI_INFO(NCCL_NET, "Test completed successfully for rank %d", rank); |
| 298 | + |
| 299 | + return res; |
| 300 | +} |
0 commit comments