Skip to content

Commit b9dd000

Browse files
Merge pull request #438 from InfiniTensor/issue/434-metax
issue/434 hccl support bf16
2 parents f9d1662 + 3bb0c93 commit b9dd000

File tree

4 files changed

+86
-4
lines changed

src/infiniccl-test/infiniccl_test.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#define TEST_INFINI_THREAD(API__) CHECK_API_OR(API__, INFINI_STATUS_SUCCESS, return nullptr)
1212

1313
const size_t MAX_COUNT = 8ULL * 1024 * 1024;
14+
// const size_t MAX_COUNT = 512 * 1024; // for metax
1415

1516
const size_t TEST_COUNTS[] = {
1617
128,
@@ -19,7 +20,7 @@ const size_t TEST_COUNTS[] = {
1920
MAX_COUNT,
2021
};
2122

22-
const infiniDtype_t TEST_DTYPES[] = {INFINI_DTYPE_F32, INFINI_DTYPE_F16};
23+
const infiniDtype_t TEST_DTYPES[] = {INFINI_DTYPE_F32, INFINI_DTYPE_F16, INFINI_DTYPE_BF16};
2324

2425
const size_t WARM_UPS = 10;
2526

@@ -51,6 +52,11 @@ void setData(infiniDtype_t dtype, void *data, size_t count, float val) {
5152
((fp16_t *)data)[i] = utils::cast<fp16_t>(val);
5253
}
5354
break;
55+
case INFINI_DTYPE_BF16:
56+
for (size_t i = 0; i < count; i++) {
57+
((bf16_t *)data)[i] = utils::cast<bf16_t>(val);
58+
}
59+
break;
5460
default:
5561
std::abort();
5662
break;
@@ -67,6 +73,12 @@ int checkData(const T *actual_, const T *expected_, size_t count) {
6773
if (std::abs(actual - expected) > 1e-4) {
6874
failed += 1;
6975
}
76+
} else if constexpr (std::is_same<T, bf16_t>::value) {
77+
float actual = utils::cast<float>(actual_[i]);
78+
float expected = utils::cast<float>(expected_[i]);
79+
if (std::abs(actual - expected) > 1e-4) {
80+
failed += 1;
81+
}
7082
} else {
7183
if (std::abs(actual_[i] - expected_[i]) > 1e-4) {
7284
failed += 1;
@@ -82,6 +94,8 @@ int checkData(const void *actual, const void *expected, infiniDtype_t dtype, siz
8294
return checkData((const float *)actual, (const float *)expected, count);
8395
case INFINI_DTYPE_F16:
8496
return checkData((const fp16_t *)actual, (const fp16_t *)expected, count);
97+
case INFINI_DTYPE_BF16:
98+
return checkData((const bf16_t *)actual, (const bf16_t *)expected, count);
8599
default:
86100
std::abort();
87101
return 1;

src/infiniccl/metax/infiniccl_metax.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ inline hcclDataType_t getHcclDtype(infiniDtype_t datatype) {
2323
return hcclFloat;
2424
case INFINI_DTYPE_F16:
2525
return hcclHalf;
26+
case INFINI_DTYPE_BF16:
27+
return hcclBfloat16;
2628
default:
2729
std::abort();
2830
return hcclHalf;
@@ -83,9 +85,7 @@ infiniStatus_t allReduce(
8385
infinicclComm_t comm,
8486
infinirtStream_t stream) {
8587

86-
if (datatype != INFINI_DTYPE_F32 && datatype != INFINI_DTYPE_F16) {
87-
return INFINI_STATUS_BAD_PARAM;
88-
}
88+
CHECK_DTYPE(datatype, INFINI_DTYPE_F32, INFINI_DTYPE_F16, INFINI_DTYPE_BF16);
8989

9090
CHECK_HCCL(hcclAllReduce(sendbuf, recvbuf, count, getHcclDtype(datatype),
9191
getHcclRedOp(op), getHcclComm(comm), getMacaStream(stream)));
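[file path missing from this capture — presumably softplus_metax.h, the softplus MetaX API header (guard __SOFTPLUS_METAX_API_H__); confirm against the commit]

Lines changed: 8 additions & 0 deletions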
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef __SOFTPLUS_METAX_API_H__
2+
#define __SOFTPLUS_METAX_API_H__
3+
4+
#include "../../../elementwise/metax/elementwise_metax_api.h"
5+
6+
ELEMENTWISE_DESCRIPTOR(softplus, metax)
7+
8+
#endif // __SOFTPLUS_METAX_API_H__
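[file path missing from this capture — the softplus MetaX implementation file, which includes "softplus_metax.h" and defines namespace op::softplus::metax; confirm the exact name against the commit]

Lines changed: 60 additions & 0 deletions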
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#include "softplus_metax.h"
2+
3+
#include "../../../elementwise/metax/elementwise_metax.h"
4+
5+
#include "../cuda/kernel.cuh"
6+
7+
namespace op::softplus::metax {
8+
9+
Descriptor::~Descriptor() = default;
10+
11+
infiniStatus_t Descriptor::create(
12+
infiniopHandle_t handle_,
13+
Descriptor **desc_ptr,
14+
infiniopTensorDescriptor_t out_desc,
15+
std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
16+
17+
auto handle = reinterpret_cast<device::metax::Handle *>(handle_);
18+
auto dtype = out_desc->dtype();
19+
20+
const auto &x_desc = input_desc_vec.at(0);
21+
const auto &y_shape = out_desc->shape();
22+
const auto &x_shape = x_desc->shape();
23+
24+
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16);
25+
26+
CHECK_SAME_SHAPE(y_shape, x_shape);
27+
28+
// create METAX elementwise descriptor
29+
CREATE_ELEMENTWISE_METAX_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)
30+
31+
return INFINI_STATUS_SUCCESS;
32+
}
33+
34+
infiniStatus_t Descriptor::calculate(
35+
void *workspace,
36+
size_t workspace_size,
37+
void *output,
38+
std::vector<const void *> inputs,
39+
void *stream) const {
40+
41+
if (workspace_size < _workspace_size) {
42+
return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
43+
}
44+
45+
switch (_dtype) {
46+
case INFINI_DTYPE_F16:
47+
return _device_info->calculate<256, cuda::SoftplusOp, half>(_info, workspace, output, inputs, stream);
48+
case INFINI_DTYPE_BF16:
49+
return _device_info->calculate<256, cuda::SoftplusOp, cuda_bfloat16>(_info, workspace, output, inputs, stream);
50+
case INFINI_DTYPE_F32:
51+
return _device_info->calculate<256, cuda::SoftplusOp, float>(_info, workspace, output, inputs, stream);
52+
case INFINI_DTYPE_F64:
53+
return _device_info->calculate<256, cuda::SoftplusOp, double>(_info, workspace, output, inputs, stream);
54+
default:
55+
return INFINI_STATUS_BAD_TENSOR_DTYPE;
56+
}
57+
58+
return INFINI_STATUS_SUCCESS;
59+
}
60+
} // namespace op::softplus::metax

0 commit comments

Comments
 (0)