Skip to content

Commit 97f9ac7

Browse files
Merge pull request #435 from InfiniTensor/issue/434-nv
issue/434 nccl support bf16
2 parents e8e25a2 + 81093e0 commit 97f9ac7

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/infiniccl/cuda/infiniccl_cuda.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ inline ncclDataType_t getNcclDtype(infiniDtype_t datatype) {
2222
return ncclFloat;
2323
case INFINI_DTYPE_F16:
2424
return ncclHalf;
25+
case INFINI_DTYPE_BF16:
26+
return ncclBfloat16;
2527
default:
2628
std::abort();
2729
return ncclHalf;
@@ -82,9 +84,7 @@ infiniStatus_t allReduce(
8284
infinicclComm_t comm,
8385
infinirtStream_t stream) {
8486

85-
if (datatype != INFINI_DTYPE_F32 && datatype != INFINI_DTYPE_F16) {
86-
return INFINI_STATUS_BAD_PARAM;
87-
}
87+
CHECK_DTYPE(datatype, INFINI_DTYPE_F32, INFINI_DTYPE_F16, INFINI_DTYPE_BF16);
8888

8989
CHECK_NCCL(ncclAllReduce(sendbuf, recvbuf, count, getNcclDtype(datatype),
9090
getNcclRedOp(op), getNcclComm(comm), getCudaStream(stream)));

0 commit comments

Comments
 (0)