Skip to content

Commit 2c8fdb6

Browse files
rauteric authored and mozarhua committed
tests: Add GIN functional test
Signed-off-by: Eric Raut <eraut@amazon.com>
1 parent 4e1f26d commit 2c8fdb6

File tree

4 files changed

+320
-4
lines changed

4 files changed

+320
-4
lines changed

tests/functional/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ nccl_connection
33
nccl_message_transfer
44
reuse_listen_comm
55
ring
6+
gin

tests/functional/Makefile.am

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,12 @@ CXXLINK = OMPI_CXX="$(CXX)" MPICH_CXX="$(CXX)" \
3333
if ENABLE_FUNC_TESTS
3434
noinst_HEADERS = test-common.h
3535

36-
bin_PROGRAMS = nccl_connection nccl_message_transfer ring inflight_close reuse_listen_comm
36+
bin_PROGRAMS = nccl_connection nccl_message_transfer ring inflight_close reuse_listen_comm gin
3737

3838
nccl_connection_SOURCES = nccl_connection.cpp
3939
nccl_message_transfer_SOURCES = nccl_message_transfer.cpp
4040
ring_SOURCES = ring.cpp
4141
inflight_close_SOURCES = inflight_close.cpp
4242
reuse_listen_comm_SOURCES = reuse_listen_comm.cpp
43+
gin_SOURCES = gin.cpp
4344
endif

tests/functional/gin.cpp

Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
/*
2+
* Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All rights reserved.
3+
*/
4+
5+
#include "config.h"
6+
7+
#include "test-common.h"
8+
9+
#include <deque>
10+
#include <vector>
11+
12+
/**
 * Poll the oldest outstanding GIN request once.
 *
 * When the request at the front of the queue reports completion it is
 * removed from the queue; otherwise the GIN progress engine is driven one
 * step. An empty queue is tolerated: nothing is tested and only progress
 * is made.
 *
 * @param extGin         GIN plugin interface used for test/ginProgress
 * @param request_deque  FIFO of outstanding request handles
 * @param collComm       collective communicator the requests belong to
 * @return ncclSuccess when no call failed. Failures are handled by
 *         OFINCCLCHECK (presumably an early return of the failing code;
 *         the macro is defined outside this file).
 */
static inline ncclResult_t
poll_request_completion(ncclGin_v11_t *extGin, std::deque<void *> &request_deque, void *collComm)
{
	if (request_deque.empty()) {
		/* front() on an empty deque is undefined behavior; just make progress */
		OFINCCLCHECK(extGin->ginProgress(collComm));
		return ncclSuccess;
	}

	/* Wait for outstanding requests */
	int done = 0;
	OFINCCLCHECK(extGin->test(collComm, request_deque.front(), &done));
	if (done) {
		request_deque.pop_front();
		return ncclSuccess;
	}

	/**
	 * Note: The NCCL GIN proxy thread will call ginProgress repeatedly
	 * until the communicator is closed. We emulate that throughout this
	 * test by ensuring each rank calls `ginProgress` until the barrier
	 * before verification is reached.
	 */
	OFINCCLCHECK(extGin->ginProgress(collComm));
	return ncclSuccess;
}
31+
32+
/**
 * Allocate a buffer, fill it with `value`, and register it with the GIN
 * plugin through regMrSym.
 *
 * @param extGin       GIN plugin interface
 * @param collComm     collective communicator to register the buffer with
 * @param size         buffer size in bytes
 * @param buffer_type  NCCL_PTR_HOST or NCCL_PTR_CUDA
 * @param value        fill value passed to initialize_buff
 * @param buff         [out] allocated buffer
 * @param mr_handle    [out] memory-registration handle returned by regMrSym
 * @return ncclSuccess on success; ncclInternalError if the plugin reported
 *         success but returned a null handle; other failures are handled by
 *         OFINCCLCHECK.
 */
static inline ncclResult_t alloc_and_reg_buff(ncclGin_v11_t *extGin, void *collComm, size_t size,
					      int buffer_type, int value, void **buff,
					      void **mr_handle)
{
	constexpr uint64_t mrFlags = 0; /* TODO FORCE_SO */
	OFINCCLCHECK(allocate_buff(buff, size, buffer_type));
	OFINCCLCHECK(initialize_buff(*buff, size, buffer_type, value));

	void *gin_handle = nullptr;
	OFINCCLCHECK(extGin->regMrSym(collComm, *buff, size, buffer_type, mrFlags, mr_handle,
				      &gin_handle));

	/*
	 * Checked at runtime rather than with assert() so the validation is
	 * not compiled out when NDEBUG is defined.
	 */
	if (*mr_handle == nullptr || gin_handle == nullptr) {
		NCCL_OFI_WARN("regMrSym succeeded but returned a null handle");
		return ncclInternalError;
	}

	return ncclSuccess;
}
47+
48+
/**
 * Check that the first SEND_SIZE bytes of `buff` all hold `send_val`.
 *
 * The buffer is copied into host memory with cudaMemcpy(cudaMemcpyDefault)
 * before it is inspected.
 *
 * @param rank      rank of the calling process (used for diagnostics only)
 * @param buff      buffer to verify
 * @param send_val  expected fill value; only its low byte is compared,
 *                  matching the memset/cudaMemset fill in initialize_buff
 * @return ncclSuccess if every byte matches, ncclSystemError on the first
 *         mismatch.
 */
static inline ncclResult_t verify_buff(int rank, void *buff, int send_val)
{
	/* memset-style fills store the value converted to a single byte */
	const uint8_t expected = static_cast<uint8_t>(send_val);

	uint8_t verif_buf[SEND_SIZE];
	CUDACHECK(cudaMemcpy(verif_buf, buff, SEND_SIZE, cudaMemcpyDefault));

	for (size_t i = 0; i < SEND_SIZE; ++i) {
		if (verif_buf[i] == expected) {
			continue;
		}
		NCCL_OFI_WARN("Test failed: verif_buf did not have expected value");
		NCCL_OFI_WARN("Rank %d, Index %zu, expected %u but got %u", rank, i,
			      static_cast<unsigned int>(expected),
			      static_cast<unsigned int>(verif_buf[i]));
		return ncclSystemError;
	}

	return ncclSuccess;
}
63+
64+
int main(int argc, char *argv[])
65+
{
66+
ncclResult_t res = ncclSuccess;
67+
int rank, nranks, proc_name_len, local_rank = 0;
68+
int buffer_type = NCCL_PTR_HOST;
69+
70+
/* Plugin defines */
71+
int ndev;
72+
73+
int dev;
74+
75+
/* Start up MPI */
76+
MPI_Init(&argc, &argv);
77+
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
78+
MPI_Comm_size(MPI_COMM_WORLD, &nranks);
79+
80+
std::vector<char[NCCL_NET_HANDLE_MAXSIZE]> handles(nranks);
81+
std::vector<void *> handles_ptrs(nranks);
82+
83+
ofi_log_function = logger;
84+
85+
if (nranks < 2) {
86+
NCCL_OFI_WARN("Expected at least two ranks but got %d. "
87+
"The gin functional test should be run with at least two ranks.",
88+
nranks);
89+
res = ncclInvalidArgument;
90+
return res;
91+
}
92+
93+
/* All processors IDs, used to find out the local rank */
94+
std::vector<char> all_proc_name(nranks * MPI_MAX_PROCESSOR_NAME);
95+
96+
MPI_Get_processor_name(&all_proc_name[PROC_NAME_IDX(rank)], &proc_name_len);
97+
MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, all_proc_name.data(),
98+
MPI_MAX_PROCESSOR_NAME, MPI_BYTE, MPI_COMM_WORLD);
99+
100+
/* Determine local rank */
101+
for (int i = 0; i < nranks; i++) {
102+
if (!strcmp(&all_proc_name[PROC_NAME_IDX(rank)],
103+
&all_proc_name[PROC_NAME_IDX(i)])) {
104+
if (i < rank) {
105+
++local_rank;
106+
}
107+
}
108+
}
109+
110+
/* Set CUDA device for subsequent device memory allocation, in case GDR is used */
111+
NCCL_OFI_TRACE(NCCL_NET, "Using CUDA device %d for memory allocation", local_rank);
112+
CUDACHECK(cudaSetDevice(local_rank));
113+
114+
/* Get external Network from NCCL-OFI library */
115+
set_system_page_size();
116+
auto *net_plugin_handle = load_netPlugin();
117+
auto *extNet = get_netPlugin_symbol(net_plugin_handle);
118+
auto *extGin = get_ginPlugin_symbol(net_plugin_handle);
119+
if (extNet == nullptr || extGin == NULL) {
120+
res = ncclInternalError;
121+
return res;
122+
}
123+
124+
void *netCtx = nullptr;
125+
ncclNetCommConfig_v11_t netConfig = {};
126+
/**
127+
* Although the net plugin isn't used in this test, the GIN plugin
128+
* requires the net plugin to be initialized, since they share some of
129+
* the underlying structures. NCCL will always initialize the net plugin
130+
* before the GIN plugin, so emulating that behavior here.
131+
*/
132+
OFINCCLCHECK(extNet->init(&netCtx, 0, &netConfig, &logger, nullptr));
133+
134+
void *ginCtx = nullptr;
135+
136+
/* Init API */
137+
OFINCCLCHECK(extGin->init(&ginCtx, 0, &logger));
138+
NCCL_OFI_INFO(NCCL_NET, "Process rank %d started. NCCL-GIN device used on %s is %s.", rank,
139+
&all_proc_name[PROC_NAME_IDX(rank)], extGin->name);
140+
141+
/* Devices API */
142+
OFINCCLCHECK(extGin->devices(&ndev));
143+
NCCL_OFI_INFO(NCCL_NET, "Received %d network devices", ndev);
144+
145+
/* Indicates if NICs support GPUDirect */
146+
std::vector<int> test_support_gdr(ndev);
147+
148+
/* Get Properties for the device */
149+
for (dev = 0; dev < ndev; dev++) {
150+
ncclNetProperties_v11_t props = {};
151+
OFINCCLCHECK(extGin->getProperties(dev, &props));
152+
153+
/* Set CUDA support */
154+
test_support_gdr[dev] = is_gdr_supported_nic(props.ptrSupport);
155+
}
156+
157+
dev = local_rank % ndev;
158+
159+
NCCL_OFI_TRACE(NCCL_INIT, "Rank %d uses %d device for communication", rank, dev);
160+
161+
if (test_support_gdr[dev] == 1) {
162+
NCCL_OFI_INFO(NCCL_INIT | NCCL_NET,
163+
"Network supports communication using CUDA buffers. Dev: %d", dev);
164+
buffer_type = NCCL_PTR_CUDA;
165+
} else {
166+
/* We aren't currently interested in the non-GDR use case for GIN */
167+
NCCL_OFI_WARN("Network does not support communication using CUDA buffers. Dev: %d",
168+
dev);
169+
return 1;
170+
}
171+
172+
void *listenComm = nullptr;
173+
OFINCCLCHECK(extGin->listen(ginCtx, dev, handles[rank], &listenComm));
174+
assert(listenComm);
175+
176+
/* Gather handles from all ranks */
177+
MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, handles.data(), NCCL_NET_HANDLE_MAXSIZE,
178+
MPI_CHAR, MPI_COMM_WORLD);
179+
180+
/* Prepare handles void** array */
181+
for (int i = 0; i < nranks; ++i) {
182+
handles_ptrs[i] = &(handles[i]);
183+
}
184+
185+
void *collComm = nullptr;
186+
OFINCCLCHECK(
187+
extGin->connect(ginCtx, handles_ptrs.data(), nranks, rank, listenComm, &collComm));
188+
assert(collComm != nullptr);
189+
190+
/* Allocate, register, and initialize all buffers to zero. */
191+
void *put_buff = nullptr;
192+
void *put_mhandle = nullptr;
193+
OFINCCLCHECK(alloc_and_reg_buff(extGin, collComm, SEND_SIZE, buffer_type, 0, &put_buff,
194+
&put_mhandle));
195+
196+
void *put_signal_buff = nullptr;
197+
void *put_signal_mhandle = nullptr;
198+
OFINCCLCHECK(alloc_and_reg_buff(extGin, collComm, SEND_SIZE, buffer_type, 0,
199+
&put_signal_buff, &put_signal_mhandle));
200+
201+
void *signal_buf = nullptr;
202+
void *signal_mhandle = nullptr;
203+
OFINCCLCHECK(alloc_and_reg_buff(extGin, collComm, sizeof(uint64_t), buffer_type, 0,
204+
&signal_buf, &signal_mhandle));
205+
206+
const int send_val = 42; /* arbitrary */
207+
const int NUM_REQS_PER_PEER = 64;
208+
assert((NUM_REQS_PER_PEER * 2) <= NCCL_OFI_MAX_REQUESTS);
209+
210+
if (rank == 0) {
211+
OFINCCLCHECK(initialize_buff(put_buff, SEND_SIZE, buffer_type, send_val));
212+
OFINCCLCHECK(initialize_buff(put_signal_buff, SEND_SIZE, buffer_type, send_val));
213+
214+
std::deque<void *> request_deque;
215+
216+
for (int dst_rank = 1; dst_rank < nranks; ++dst_rank) {
217+
/* iput API */
218+
for (int i = 0; i < NUM_REQS_PER_PEER; ++i) {
219+
void *request = nullptr;
220+
OFINCCLCHECK(extGin->iput(collComm, 0, put_mhandle, SEND_SIZE, 0,
221+
put_mhandle, dst_rank, &request));
222+
assert(request != nullptr);
223+
request_deque.push_back(request);
224+
}
225+
226+
/* iputSignal API */
227+
for (int i = 0; i < NUM_REQS_PER_PEER; ++i) {
228+
/* TODO: Expand the test to cover other signal types, such as
229+
* NCCL_NET_SIGNAL_OP_ADD */
230+
void *request = nullptr;
231+
OFINCCLCHECK(extGin->iputSignal(collComm, 0, put_signal_mhandle,
232+
SEND_SIZE, 0, put_signal_mhandle,
233+
dst_rank, 0, signal_mhandle, 1,
234+
NCCL_NET_SIGNAL_OP_INC, &request));
235+
assert(request != nullptr);
236+
request_deque.push_back(request);
237+
}
238+
}
239+
240+
/* Wait for remaining requests */
241+
while (!request_deque.empty()) {
242+
OFINCCLCHECK(poll_request_completion(extGin, request_deque, collComm));
243+
}
244+
} else {
245+
/* Validate that the signal_buff reaches the designated signal value */
246+
uint64_t signal_h = 0;
247+
while (signal_h != NUM_REQS_PER_PEER) {
248+
OFINCCLCHECK(extGin->ginProgress(collComm));
249+
CUDACHECK(cudaMemcpy(&signal_h, signal_buf, sizeof(uint64_t),
250+
cudaMemcpyDefault));
251+
}
252+
}
253+
254+
MPI_Request barrier_req;
255+
MPI_Ibarrier(MPI_COMM_WORLD, &barrier_req);
256+
int barrier_done = 0;
257+
while (!barrier_done) {
258+
/* Make progress on comm until all ranks reach the barrier */
259+
OFINCCLCHECK(extGin->ginProgress(collComm));
260+
MPI_Test(&barrier_req, &barrier_done, MPI_STATUS_IGNORE);
261+
}
262+
263+
/* Verification */
264+
NCCL_OFI_INFO(NCCL_NET, "Verifying result..");
265+
OFINCCLCHECK(verify_buff(rank, put_buff, send_val));
266+
OFINCCLCHECK(verify_buff(rank, put_signal_buff, send_val));
267+
268+
/* Cleanup APIs */
269+
OFINCCLCHECK(extGin->deregMrSym(collComm, signal_mhandle));
270+
signal_mhandle = nullptr;
271+
OFINCCLCHECK(extGin->deregMrSym(collComm, put_signal_mhandle));
272+
put_signal_mhandle = nullptr;
273+
OFINCCLCHECK(extGin->deregMrSym(collComm, put_mhandle));
274+
put_mhandle = nullptr;
275+
276+
OFINCCLCHECK(extGin->closeColl(collComm));
277+
collComm = nullptr;
278+
OFINCCLCHECK(extGin->closeListen(listenComm));
279+
listenComm = nullptr;
280+
281+
OFINCCLCHECK(extGin->finalize(ginCtx));
282+
OFINCCLCHECK(extNet->finalize(netCtx));
283+
284+
dlclose(net_plugin_handle);
285+
286+
MPI_Barrier(MPI_COMM_WORLD);
287+
MPI_Finalize();
288+
289+
/* Clean up local resources */
290+
OFINCCLCHECK(deallocate_buffer(signal_buf, buffer_type));
291+
signal_buf = nullptr;
292+
OFINCCLCHECK(deallocate_buffer(put_signal_buff, buffer_type));
293+
put_signal_buff = nullptr;
294+
OFINCCLCHECK(deallocate_buffer(put_buff, buffer_type));
295+
put_buff = nullptr;
296+
297+
NCCL_OFI_INFO(NCCL_NET, "Test completed successfully for rank %d", rank);
298+
299+
return res;
300+
}

tests/functional/test-common.h

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,10 @@
8484

8585
// Can be changed when porting new versions to the plugin
8686
#define NCCL_PLUGIN_SYMBOL ncclNetPlugin_v11
87+
#define NCCL_GIN_PLUGIN_SYMBOL ncclGinPlugin_v11
8788

8889
typedef ncclNet_v11_t test_nccl_net_t;
90+
typedef ncclGin_v11_t test_nccl_gin_t;
8991
typedef ncclNetProperties_v11_t test_nccl_properties_t;
9092
typedef ncclNetDeviceHandle_v11_t test_nccl_net_device_handle_t;
9193
typedef ncclNetCommConfig_v11_t test_nccl_net_config_t;
@@ -181,16 +183,18 @@ static inline ncclResult_t allocate_buff(void **buf, size_t size, int buffer_typ
181183
* @param buf Buffer to initialize
182184
* @param size Size of buffer in bytes
183185
* @param buffer_type NCCL_PTR_HOST for host memory, NCCL_PTR_CUDA for device memory
186+
* @param value value for each element, default '1'
184187
* @return ncclSuccess on success, error code otherwise
185188
*/
186-
static inline ncclResult_t initialize_buff(void *buf, size_t size, int buffer_type)
189+
static inline ncclResult_t initialize_buff(void *buf, size_t size, int buffer_type, int value='1')
187190
{
188191
switch (buffer_type) {
189192
case NCCL_PTR_CUDA:
190-
CUDACHECK(cudaMemset(buf, '1', size));
193+
CUDACHECK(cudaMemset(buf, value, size));
194+
CUDACHECK(cudaStreamSynchronize(cudaStreamDefault));
191195
break;
192196
case NCCL_PTR_HOST:
193-
memset(buf, '1', size);
197+
memset(buf, value, size);
194198
break;
195199
default:
196200
NCCL_OFI_WARN("Unidentified buffer type: %d", buffer_type);
@@ -455,6 +459,16 @@ static inline test_nccl_net_t *get_netPlugin_symbol(void *netPluginLib)
455459
return extNet;
456460
}
457461

462+
/**
 * Resolve the GIN plugin interface symbol (NCCL_GIN_PLUGIN_SYMBOL) in a
 * library handle via dlsym.
 *
 * @param netPluginLib  library handle passed to dlsym
 * @return pointer to the GIN plugin interface, or NULL (after logging a
 *         warning) when the symbol cannot be found
 */
static inline test_nccl_gin_t *get_ginPlugin_symbol(void *netPluginLib)
{
	void *sym = dlsym(netPluginLib, STR(NCCL_GIN_PLUGIN_SYMBOL));
	if (sym != NULL) {
		return static_cast<test_nccl_gin_t *>(sym);
	}

	NCCL_OFI_WARN("GinPlugin, could not find %s symbol",
		      STR(NCCL_GIN_PLUGIN_SYMBOL));
	return NULL;
}
471+
458472
static inline test_nccl_net_t *get_extNet(void)
459473
{
460474
set_system_page_size();

0 commit comments

Comments
 (0)