Skip to content

Commit ca89a1a

Browse files
committed
Disable CTS Offload and CTS Inline when NIC Fusion is enabled in net_ib_rocm
1 parent 8839b04 commit ca89a1a

File tree

2 files changed: +62 additions, -34 deletions

projects/rccl/ext-src/rocm_netib.patch

Lines changed: 48 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -74,7 +74,7 @@ index 9bfd8dcf..4d3f0a08 100644
7474

7575
// Detect IB cards
7676
int nIbDevs = 0;
77-
@@ -944,6 +978,23 @@ ncclResult_t ncclIbInit(void** ctx, uint64_t commId, ncclNetCommConfig_t* config
77+
@@ -944,6 +978,37 @@ ncclResult_t ncclIbInit(void** ctx, uint64_t commId, ncclNetCommConfig_t* config
7878
INFO(NCCL_INIT|NCCL_NET, "NET/IB : Using%s %s; OOB %s:%s", line, ncclIbRelaxedOrderingEnabled ? "[RO]" : "",
7979
ncclIbIfName, ncclSocketToString(&ncclIbIfAddr, addrline));
8080

@@ -86,6 +86,20 @@ index 9bfd8dcf..4d3f0a08 100644
8686
+ // for AINIC, these params are defaulted to enabled unless user forces it to disable(0).
8787
+ rcclCtsInlineData = ((rcclParamCtsInlineData() == 0) ? false : true);
8888
+ rcclCtsOffloadEnabled = ((rcclParamCtsOffloadEnabled() == 0) ? false : true);
89+
+
90+
+ // CTS Offload and CTS Inline are not yet compatible with NIC Fusion
91+
+ // (NCCL_IB_MERGE_NICS). Temporarily disable them when merge is enabled.
92+
+ if (ncclParamRocmIbMergeNics()) {
93+
+ if (rcclCtsInlineData) {
94+
+ INFO(NCCL_INIT|NCCL_NET, "NET/IB : NIC Fusion enabled - disabling CTS Inline Data (not yet supported with merge)");
95+
+ rcclCtsInlineData = false;
96+
+ }
97+
+ if (rcclCtsOffloadEnabled) {
98+
+ INFO(NCCL_INIT|NCCL_NET, "NET/IB : NIC Fusion enabled - disabling CTS Offload (not yet supported with merge)");
99+
+ rcclCtsOffloadEnabled = false;
100+
+ }
101+
+ }
102+
+
89103
+ // for AINIC IbUseInline is enabled by default always
90104
+ ncclIbUseInline = true;
91105
+ // for AINIC GDR flush is disabled by default
@@ -98,7 +112,7 @@ index 9bfd8dcf..4d3f0a08 100644
98112
}
99113
exit:
100114
ibContext.trafficClass = config->trafficClass;
101-
@@ -1271,6 +1322,8 @@ struct ncclIbListenComm {
115+
@@ -1271,6 +1336,8 @@ struct ncclIbListenComm {
102116
struct ncclIbCommStage stage;
103117
};
104118

@@ -107,7 +121,7 @@ index 9bfd8dcf..4d3f0a08 100644
107121
struct alignas(64) ncclIbSendFifo {
108122
uint64_t addr;
109123
uint64_t size;
110-
@@ -1281,10 +1334,21 @@ struct alignas(64) ncclIbSendFifo {
124+
@@ -1281,10 +1348,21 @@ struct alignas(64) ncclIbSendFifo {
111125
char padding[16];
112126
};
113127

@@ -129,31 +143,31 @@ index 9bfd8dcf..4d3f0a08 100644
129143
};
130144

131145
struct ncclIbRemSizesFifo {
132-
@@ -1331,6 +1395,7 @@ struct ncclIbSendComm {
146+
@@ -1331,6 +1409,7 @@ struct ncclIbSendComm {
133147
struct ncclIbNetCommBase base;
134148
// Start with fifo and ibv structs as they have alignment restrictions
135149
struct ncclIbSendFifo fifo[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS];
136150
+ struct ncclIbSendFifoCtsInline fifo_inline[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS];
137151
struct ibv_sge sges[NCCL_NET_IB_MAX_RECVS];
138152
struct ibv_send_wr wrs[NCCL_NET_IB_MAX_RECVS + 1];
139153
// Each dev correlates to a mergedIbDev
140-
@@ -1346,6 +1411,7 @@ struct ncclIbSendComm {
154+
@@ -1346,6 +1425,7 @@ struct ncclIbSendComm {
141155
static_assert((sizeof(struct ncclIbNetCommBase) % 32) == 0, "ncclIbNetCommBase size must be 32-byte multiple to ensure fifo is at proper offset");
142156
static_assert((offsetof(struct ncclIbSendComm, fifo) % 32) == 0, "ncclIbSendComm fifo must be 32-byte aligned");
143157
static_assert((sizeof(struct ncclIbSendFifo) % 32) == 0, "ncclIbSendFifo element size must be 32-byte multiples");
144158
+static_assert((sizeof(struct ncclIbSendFifoCtsInline) % 32) == 0, "ncclIbSendFifoCtsInline element size must be 32-byte multiples");
145159
static_assert((offsetof(struct ncclIbSendComm, sges) % 32) == 0, "sges must be 32-byte aligned");
146160
static_assert((offsetof(struct ncclIbSendComm, wrs) % 32) == 0, "wrs must be 32-byte aligned");
147161

148-
@@ -1360,6 +1426,7 @@ struct ncclIbGpuFlush {
162+
@@ -1360,6 +1440,7 @@ struct ncclIbGpuFlush {
149163

150164
struct ncclIbRemFifo {
151165
struct ncclIbSendFifo elems[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS];
152166
+ struct ncclIbSendFifoCtsInline elems_cts_inline[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS];
153167
uint64_t fifoTail;
154168
uint64_t addr;
155169
uint32_t flags;
156-
@@ -1415,20 +1482,59 @@ ncclResult_t ncclIbDestroyBase(struct ncclIbNetCommDevBase* base) {
170+
@@ -1415,20 +1496,59 @@ ncclResult_t ncclIbDestroyBase(struct ncclIbNetCommDevBase* base) {
157171
return ncclSuccess;
158172
}
159173

@@ -215,7 +229,7 @@ index 9bfd8dcf..4d3f0a08 100644
215229
struct ibv_qp_attr qpAttr;
216230
memset(&qpAttr, 0, sizeof(struct ibv_qp_attr));
217231
qpAttr.qp_state = IBV_QPS_INIT;
218-
@@ -1438,6 +1544,9 @@ ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbNetCommDevBase* base,
232+
@@ -1438,6 +1558,9 @@ ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbNetCommDevBase* base,
219233
NCCLCHECK(wrap_ibv_modify_qp(qp->qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS));
220234
TRACE(NCCL_NET, "NET/IB : ncclIbCreateQp port=%d dev=%d devName=%s ndevs=%d nmdevs=%d qpn=%u pkey=%u pd=%p",
221235
ib_port, base->ibDevN, ncclIbDevs[base->ibDevN].devName, ncclNIbDevs, ncclNMergedIbDevs, qp->qp->qp_num, qpAttr.pkey_index, base->pd);
@@ -225,7 +239,7 @@ index 9bfd8dcf..4d3f0a08 100644
225239
return ncclSuccess;
226240
}
227241

228-
@@ -1521,7 +1630,7 @@ fail:
242+
@@ -1521,7 +1644,7 @@ fail:
229243
goto exit;
230244
}
231245

@@ -234,7 +248,7 @@ index 9bfd8dcf..4d3f0a08 100644
234248
ncclResult_t ret = ncclSuccess;
235249
struct ncclIbHandle* handle = (struct ncclIbHandle*) opaqueHandle;
236250
struct ncclIbCommStage* stage = &handle->stage;
237-
@@ -1529,8 +1638,13 @@ ncclResult_t ncclIbConnect(void* ctx, int dev, void* opaqueHandle, void** sendCo
251+
@@ -1529,8 +1652,13 @@ ncclResult_t ncclIbConnect(void* ctx, int dev, void* opaqueHandle, void** sendCo
238252
int ready;
239253
uint8_t link_layer = IBV_LINK_LAYER_UNSPECIFIED;
240254
int isP2p = 0;
@@ -248,7 +262,7 @@ index 9bfd8dcf..4d3f0a08 100644
248262
if (stage->state == ncclIbCommStateConnect) goto ib_connect_check;
249263
if (stage->state == ncclIbCommStateSendDevList) goto ib_send_dev_list;
250264
if (stage->state == ncclIbCommStateRecvDevList) goto ib_recv_dev_list;
251-
@@ -1612,7 +1726,7 @@ ib_recv_dev_list:
265+
@@ -1612,7 +1740,7 @@ ib_recv_dev_list:
252266
for (int q = 0; q < comm->base.nqps; q++) {
253267
ncclIbSendCommDev* commDev = comm->devs + devIndex;
254268
ncclIbDev* ibDev = ncclIbDevs + commDev->base.ibDevN;
@@ -257,7 +271,7 @@ index 9bfd8dcf..4d3f0a08 100644
257271
comm->base.qps[q].devIndex = devIndex;
258272
meta.qpInfo[q].qpn = comm->base.qps[q].qp->qp_num;
259273
meta.qpInfo[q].devIndex = comm->base.qps[q].devIndex;
260-
@@ -1637,7 +1751,11 @@ ib_recv_dev_list:
274+
@@ -1637,7 +1765,11 @@ ib_recv_dev_list:
261275
devInfo->lid = ibDev->portAttr.lid;
262276
devInfo->ibv_dev_index = commDev->base.ibDevN;
263277
// Prepare my fifo
@@ -270,7 +284,7 @@ index 9bfd8dcf..4d3f0a08 100644
270284
devInfo->fifoRkey = commDev->fifoMr->rkey;
271285

272286
// Pack local GID info
273-
@@ -1680,7 +1798,11 @@ ib_recv_dev_list:
287+
@@ -1680,7 +1812,11 @@ ib_recv_dev_list:
274288
}
275289
}
276290
config = (ncclNetCommConfig_t*)ctx;
@@ -283,7 +297,7 @@ index 9bfd8dcf..4d3f0a08 100644
283297
meta.sl = (ncclParamIbSl() != -1) ? ncclParamIbSl() : (config && config->trafficClass != NCCL_NET_TRAFFIC_CLASS_UNDEF) ? config->trafficClass : NCCL_IB_SL_DEFAULT;
284298
meta.tc = (ncclParamIbTc() != -1) ? ncclParamIbTc() : (config && config->trafficClass != NCCL_NET_TRAFFIC_CLASS_UNDEF) ? config->trafficClass : NCCL_IB_TC_DEFAULT;
285299
strncpy(meta.devName, mergedDev->devName, MAX_MERGED_DEV_NAME);
286-
@@ -1825,18 +1947,22 @@ ncclResult_t ncclIbCheckVProps(ncclNetVDeviceProps_t* vProps1, ncclNetVDevicePro
300+
@@ -1825,18 +1961,22 @@ ncclResult_t ncclIbCheckVProps(ncclNetVDeviceProps_t* vProps1, ncclNetVDevicePro
287301
return ncclSuccess;
288302
}
289303

@@ -308,7 +322,7 @@ index 9bfd8dcf..4d3f0a08 100644
308322
if (stage->state == ncclIbCommStateAccept) goto ib_accept_check;
309323
if (stage->state == ncclIbCommStateRecvDevList) goto ib_recv_dev_list;
310324
if (stage->state == ncclIbCommStateSendDevList) goto ib_send_dev_list;
311-
@@ -1966,7 +2092,7 @@ ib_recv:
325+
@@ -1966,7 +2106,7 @@ ib_recv:
312326
// Local ibDevN
313327
ibDevN = rComm->devs[devIndex].base.ibDevN;
314328
ibDev = ncclIbDevs + ibDevN;
@@ -317,7 +331,7 @@ index 9bfd8dcf..4d3f0a08 100644
317331
qp->devIndex = devIndex;
318332
devIndex = (devIndex + 1) % rComm->base.vProps.ndevs;
319333

320-
@@ -1992,16 +2118,22 @@ ib_recv:
334+
@@ -1992,16 +2132,22 @@ ib_recv:
321335

322336
useDmaBuf = (ncclIbDmaBufSupport(lComm->dev) == ncclSuccess);
323337
rComm->flushEnabled = ((ncclIbGdrSupport() == ncclSuccess || useDmaBuf)
@@ -343,7 +357,7 @@ index 9bfd8dcf..4d3f0a08 100644
343357

344358
// Allocate Flush dummy buffer for GPU Direct RDMA
345359
if (rComm->flushEnabled) {
346-
@@ -2039,7 +2171,7 @@ ib_recv:
360+
@@ -2039,7 +2185,7 @@ ib_recv:
347361
rCommDev->gpuFlush.sge.addr = (uint64_t)&rComm->gpuFlushHostMem;
348362
rCommDev->gpuFlush.sge.length = 1;
349363
rCommDev->gpuFlush.sge.lkey = rCommDev->gpuFlush.hostMr->lkey;
@@ -352,7 +366,7 @@ index 9bfd8dcf..4d3f0a08 100644
352366
struct ncclIbDevInfo devInfo;
353367
devInfo.lid = ibDev->portAttr.lid;
354368
devInfo.link_layer = ibDev->portAttr.link_layer;
355-
@@ -2257,10 +2389,15 @@ ncclResult_t ncclIbDeregMr(void* comm, void* mhandle) {
369+
@@ -2257,10 +2403,15 @@ ncclResult_t ncclIbDeregMr(void* comm, void* mhandle) {
356370

357371
NCCL_PARAM(IbSplitDataOnQps, "IB_SPLIT_DATA_ON_QPS", 0);
358372

@@ -370,7 +384,7 @@ index 9bfd8dcf..4d3f0a08 100644
370384
if (nreqs > NCCL_NET_IB_MAX_RECVS) return ncclInternalError;
371385

372386
uint64_t wr_id = 0ULL;
373-
@@ -2272,7 +2409,11 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) {
387+
@@ -2272,7 +2423,11 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) {
374388
sge->addr=(uintptr_t)reqs[r]->send.data;
375389
wr->opcode = IBV_WR_RDMA_WRITE;
376390
wr->send_flags = 0;
@@ -383,7 +397,7 @@ index 9bfd8dcf..4d3f0a08 100644
383397
wr->next = wr + 1;
384398
wr_id += (reqs[r] - comm->base.reqs) << (r*8);
385399
#ifdef NCCL_ENABLE_NET_PROFILING
386-
@@ -2283,7 +2424,7 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) {
400+
@@ -2283,7 +2438,7 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) {
387401
// Write size as immediate data. In the case of multi-send, only write
388402
// 0 or 1 as size to indicate whether there was data sent or received.
389403
uint32_t immData = 0;
@@ -392,7 +406,7 @@ index 9bfd8dcf..4d3f0a08 100644
392406
immData = reqs[0]->send.size;
393407
} else {
394408
int* sizes = comm->remSizesFifo.elems[slot];
395-
@@ -2293,22 +2434,24 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) {
409+
@@ -2293,22 +2448,24 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) {
396410
}
397411

398412
struct ibv_send_wr* lastWr = comm->wrs+nreqs-1;
@@ -430,7 +444,7 @@ index 9bfd8dcf..4d3f0a08 100644
430444
lastWr->next = NULL;
431445
lastWr->send_flags = IBV_SEND_SIGNALED;
432446

433-
@@ -2324,7 +2467,11 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) {
447+
@@ -2324,7 +2481,11 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) {
434448
//ncclIbAddEvent(reqs[r], devIndex, &comm->devs[devIndex].base);
435449

436450
// Select proper rkey (needed even for 0-size send)
@@ -443,7 +457,7 @@ index 9bfd8dcf..4d3f0a08 100644
443457

444458
int chunkSize = DIVUP(DIVUP(reqs[r]->send.size, nqps), align) * align;
445459
int length = std::min(reqs[r]->send.size-reqs[r]->send.offset, chunkSize);
446-
@@ -2340,7 +2487,7 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) {
460+
@@ -2340,7 +2501,7 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) {
447461
}
448462
}
449463

@@ -452,7 +466,7 @@ index 9bfd8dcf..4d3f0a08 100644
452466
// Also make sure lastWr writes remote sizes using the right lkey
453467
comm->remSizesFifo.sge.lkey = comm->remSizesFifo.mrs[devIndex]->lkey;
454468
lastWr->wr.rdma.rkey = comm->remSizesFifo.rkeys[devIndex];
455-
@@ -2398,32 +2545,46 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, size_t size, int tag, void*
469+
@@ -2398,32 +2559,46 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, size_t size, int tag, void*
456470
NCCLCHECK(ncclIbStatsCheckFatalCount(&comm->base.stats,__func__));
457471

458472
struct ncclIbMrHandle* mhandleWrapper = (struct ncclIbMrHandle*) mhandle;
@@ -517,7 +531,7 @@ index 9bfd8dcf..4d3f0a08 100644
517531
}
518532

519533
struct ncclIbRequest* req;
520-
@@ -2467,10 +2628,12 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, size_t size, int tag, void*
534+
@@ -2467,10 +2642,12 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, size_t size, int tag, void*
521535
}
522536

523537
TIME_START(0);
@@ -532,7 +546,7 @@ index 9bfd8dcf..4d3f0a08 100644
532546
memset(reqs, 0, NCCL_NET_IB_MAX_RECVS*sizeof(struct ncclIbRequest*));
533547
comm->fifoHead++;
534548
TIME_STOP(0);
535-
@@ -2483,30 +2646,60 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, size_t size, int tag, void*
549+
@@ -2483,30 +2660,60 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, size_t size, int tag, void*
536550

537551
ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, size_t* sizes, int* tags, void** mhandles, struct ncclIbRequest* req) {
538552
struct ibv_send_wr wr;
@@ -606,7 +620,7 @@ index 9bfd8dcf..4d3f0a08 100644
606620
}
607621
wr.wr.rdma.remote_addr = comm->remFifo.addr + slot*NCCL_NET_IB_MAX_RECVS*sizeof(struct ncclIbSendFifo);
608622

609-
@@ -2514,8 +2707,12 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, siz
623+
@@ -2514,8 +2721,12 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, siz
610624
wr.wr.rdma.rkey = comm->base.remDevs[ctsQp->remDevIdx].fifoRkey;
611625

612626
// Set the correct sge properties
@@ -621,7 +635,7 @@ index 9bfd8dcf..4d3f0a08 100644
621635
wr.sg_list = &comm->devs[ctsQp->devIndex].fifoSge;
622636
wr.num_sge = 1;
623637

624-
@@ -2545,7 +2742,13 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, siz
638+
@@ -2545,7 +2756,13 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, siz
625639
//
626640
// slot == devIndex - When writing to fifo slot N, and this QP lives on device index N, it should send signalled.
627641
// This works out that each fifo posting QP gets drained
@@ -636,7 +650,7 @@ index 9bfd8dcf..4d3f0a08 100644
636650
wr.send_flags |= IBV_SEND_SIGNALED;
637651
wr.wr_id = req - comm->base.reqs;
638652
ncclIbAddEvent(req, ctsQp->devIndex, &comm->devs[ctsQp->devIndex].base);
639-
@@ -2560,10 +2763,16 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, siz
653+
@@ -2560,10 +2777,16 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, siz
640654

641655
comm->remFifo.fifoTail++;
642656

@@ -653,7 +667,7 @@ index 9bfd8dcf..4d3f0a08 100644
653667
struct ncclIbRecvComm* comm = (struct ncclIbRecvComm*)recvComm;
654668
if (comm->base.ready == 0) {
655669
WARN("NET/IB: ncclIbIrecv() called when comm->base.ready == 0");
656-
@@ -2573,6 +2782,11 @@ ncclResult_t ncclIbIrecv(void* recvComm, int n, void** data, size_t* sizes, int*
670+
@@ -2573,6 +2796,11 @@ ncclResult_t ncclIbIrecv(void* recvComm, int n, void** data, size_t* sizes, int*
657671
if (n > NCCL_NET_IB_MAX_RECVS) return ncclInternalError;
658672
NCCLCHECK(ncclIbStatsCheckFatalCount(&comm->base.stats,__func__));
659673

@@ -665,7 +679,7 @@ index 9bfd8dcf..4d3f0a08 100644
665679
struct ncclIbRequest* req;
666680
NCCLCHECK(ncclIbGetRequest(&comm->base, &req));
667681
req->type = NCCL_NET_IB_REQ_RECV;
668-
@@ -2586,50 +2800,65 @@ ncclResult_t ncclIbIrecv(void* recvComm, int n, void** data, size_t* sizes, int*
682+
@@ -2586,50 +2814,65 @@ ncclResult_t ncclIbIrecv(void* recvComm, int n, void** data, size_t* sizes, int*
669683
req->devBases[i] = &comm->devs[i].base;
670684
}
671685

@@ -763,7 +777,7 @@ index 9bfd8dcf..4d3f0a08 100644
763777
}
764778

765779
ncclResult_t ncclIbIflush(void* recvComm, int n, void** data, int* sizes, void** mhandles, void** request) {
766-
@@ -2698,6 +2927,8 @@ static int getReqQpIndex(struct ncclIbRequest* req, int request, int qpNumber) {
780+
@@ -2698,6 +2941,8 @@ static int getReqQpIndex(struct ncclIbRequest* req, int request, int qpNumber) {
767781
}
768782
#endif
769783

@@ -772,7 +786,7 @@ index 9bfd8dcf..4d3f0a08 100644
772786
ncclResult_t ncclIbTest(void* request, int* done, int* sizes) {
773787
struct ncclIbRequest *r = (struct ncclIbRequest*)request;
774788
*done = 0;
775-
@@ -2731,13 +2962,18 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) {
789+
@@ -2731,13 +2976,18 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) {
776790

777791
int totalWrDone = 0;
778792
int wrDone = 0;
@@ -793,7 +807,7 @@ index 9bfd8dcf..4d3f0a08 100644
793807
totalWrDone += wrDone;
794808
if (wrDone == 0) { TIME_CANCEL(3); } else { TIME_STOP(3); }
795809
if (wrDone == 0) continue;
796-
@@ -2889,7 +3125,7 @@ ncclResult_t rcclNetP2pPolicy(void* handle, int isP2p) {
810+
@@ -2889,7 +3139,7 @@ ncclResult_t rcclNetP2pPolicy(void* handle, int isP2p) {
797811
}
798812

799813
ncclNet_t ncclNetIb = {

projects/rccl/src/transport/net_ib_rocm.cc

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -986,6 +986,20 @@ ncclResult_t rocmIbInit(void** ctx, uint64_t commId, ncclNetCommConfig_t* config
986986
// for AINIC, these params are defaulted to enabled unless user forces it to disable(0).
987987
rcclCtsInlineData = ((rcclParamCtsInlineData() == 0) ? false : true);
988988
rcclCtsOffloadEnabled = ((rcclParamCtsOffloadEnabled() == 0) ? false : true);
989+
990+
// CTS Offload and CTS Inline are not yet compatible with NIC Fusion
991+
// (NCCL_IB_MERGE_NICS). Temporarily disable them when merge is enabled.
992+
if (ncclParamRocmIbMergeNics()) {
993+
if (rcclCtsInlineData) {
994+
INFO(NCCL_INIT|NCCL_NET, "NET/IB : NIC Fusion enabled - disabling CTS Inline Data (not yet supported with merge)");
995+
rcclCtsInlineData = false;
996+
}
997+
if (rcclCtsOffloadEnabled) {
998+
INFO(NCCL_INIT|NCCL_NET, "NET/IB : NIC Fusion enabled - disabling CTS Offload (not yet supported with merge)");
999+
rcclCtsOffloadEnabled = false;
1000+
}
1001+
}
1002+
9891003
// for AINIC IbUseInline is enabled by default always
9901004
ncclIbUseInline = true;
9911005
// for AINIC GDR flush is disabled by default

Comments: 0 commit comments