Commit 9b758b9

Merge pull request #382 from InfiniTensor/issue/375
Issue/375: Add BF16 floating-point computation support to GEMM on the Moore platform main branch, and rename "musa" to "moore" in the Moore platform files
2 parents 831021b + 66a8eb9 commit 9b758b9
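
The functional core of the change is an extra BF16 branch in the muBLAS dispatch of the new Moore GEMM implementation. A minimal sketch of that dtype-to-compute-type mapping, lifted from the new gemm_moore.mu shown below (the surrounding declarations of a_type, b_type, c_type and compute_type are assumed from that file):

// Sketch only: dtype -> muBLAS type mapping used by Descriptor::calculate() below.
switch (_dtype) {
case INFINI_DTYPE_BF16:                          // new in this commit
    a_type = b_type = c_type = MUSA_R_16BF;      // BF16 storage for A, B and C
    compute_type = MUBLAS_COMPUTE_32F;           // accumulate in FP32; alpha/beta stay float
    break;
case INFINI_DTYPE_F16:
    a_type = b_type = c_type = MUSA_R_16F;
    compute_type = MUBLAS_COMPUTE_16F;           // alpha/beta must be converted to half
    break;
case INFINI_DTYPE_F32:
    a_type = b_type = c_type = MUSA_R_32F;
    compute_type = MUBLAS_COMPUTE_32F_FAST_TF32; // TF32 fast path for FP32
    break;
default:
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}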

File tree

19 files changed (+192, -188 lines)


src/infiniop/devices/handle.cc

Lines changed: 3 additions & 3 deletions

@@ -15,7 +15,7 @@
#include "ascend/ascend_handle.h"
#endif
#ifdef ENABLE_MOORE_API
-#include "musa/musa_handle.h"
+#include "moore/moore_handle.h"
#endif
#ifdef ENABLE_KUNLUN_API
#include "kunlun/kunlun_handle.h"
@@ -54,7 +54,7 @@ __C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) {
CREATE(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_MOORE_API
-CREATE(INFINI_DEVICE_MOORE, musa);
+CREATE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_KUNLUN_API
CREATE(INFINI_DEVICE_KUNLUN, kunlun);
@@ -94,7 +94,7 @@ __C infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle) {
DELETE(INFINI_DEVICE_ASCEND, ascend);
#endif
#ifdef ENABLE_MOORE_API
-DELETE(INFINI_DEVICE_MOORE, musa);
+DELETE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_KUNLUN_API
DELETE(INFINI_DEVICE_KUNLUN, kunlun);

src/infiniop/devices/musa/common_musa.h renamed to src/infiniop/devices/moore/moore_common.h

Lines changed: 3 additions & 3 deletions

@@ -1,6 +1,6 @@
#include "../../../utils.h"
#include "../pool.h"
-#include "musa_handle.h"
+#include "moore_handle.h"
#include <mublas.h>
#include <mudnn.h>
#include <musa.h>
@@ -10,7 +10,7 @@
#define CHECK_MUBLAS(API) CHECK_INTERNAL(API, MUBLAS_STATUS_SUCCESS)
#define CHECK_MUDNN(API) CHECK_INTERNAL((int)API, (int)::musa::dnn::Status::SUCCESS)

-namespace device::musa {
+namespace device::moore {

class Handle::Internal {
Pool<std::unique_ptr<mublasHandle_t>> mublas_handles;
@@ -39,4 +39,4 @@ class Handle::Internal {
int gridSizeZ() const;
};

-} // namespace device::musa
+} // namespace device::moore

src/infiniop/devices/musa/musa_handle.cc renamed to src/infiniop/devices/moore/moore_handle.cc

Lines changed: 3 additions & 3 deletions

@@ -1,6 +1,6 @@
-#include "common_musa.h"
+#include "moore_common.h"

-namespace device::musa {
+namespace device::moore {
Handle::Handle(infiniDevice_t device, int device_id)
: InfiniopHandle{device, device_id},
_internal(std::make_shared<Handle::Internal>(device_id)) {}
@@ -67,4 +67,4 @@ infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
return INFINI_STATUS_SUCCESS;
}

-} // namespace device::musa
+} // namespace device::moore

src/infiniop/devices/musa/musa_handle.h renamed to src/infiniop/devices/moore/moore_handle.h

Lines changed: 5 additions & 5 deletions

@@ -1,10 +1,10 @@
-#ifndef __INFINIOP_MUSA_HANDLE_H__
-#define __INFINIOP_MUSA_HANDLE_H__
+#ifndef __INFINIOP_MOORE_HANDLE_H__
+#define __INFINIOP_MOORE_HANDLE_H__

#include "../../handle.h"
#include <memory>

-namespace device::musa {
+namespace device::moore {
struct Handle : public InfiniopHandle {
Handle(int device_id);
class Internal;
@@ -20,6 +20,6 @@ struct Handle : public InfiniopHandle {
std::shared_ptr<Internal> _internal;
};

-} // namespace device::musa
+} // namespace device::moore

-#endif // __INFINIOP_MUSA_HANDLE_H__
+#endif // __INFINIOP_MOORE_HANDLE_H__

src/infiniop/devices/musa/musa_kernel_common.h renamed to src/infiniop/devices/moore/moore_kernel_common.h

Lines changed: 7 additions & 7 deletions

@@ -1,20 +1,20 @@
-#define INFINIOP_MUSA_KERNEL __global__ void
+#define INFINIOP_MOORE_KERNEL __global__ void

#include <musa_bf16.h>
#include <musa_fp16.h>

// Posible maximum number of threads per block for MUSA architectures
// Used for picking correct kernel launch configuration
-#define MUSA_BLOCK_SIZE_2048 2048
-#define MUSA_BLOCK_SIZE_1024 1024
-#define MUSA_BLOCK_SIZE_512 512
+#define MOORE_BLOCK_SIZE_2048 2048
+#define MOORE_BLOCK_SIZE_1024 1024
+#define MOORE_BLOCK_SIZE_512 512

-#define CHECK_MUSA(API) CHECK_INTERNAL(API, musaSuccess)
+#define CHECK_MOORE(API) CHECK_INTERNAL(API, musaSuccess)

using musa_bfloat16 = mt_bfloat16;
using musa_bfloat162 = mt_bfloat162;

-namespace device::musa {
+namespace device::moore {

// return the memory offset of original tensor, given the flattened index of broadcasted tensor
__forceinline__ __device__ __host__ size_t
@@ -45,7 +45,7 @@ indexToOffset(
}
return res;
}
-} // namespace device::musa
+} // namespace device::moore

__forceinline__ __device__ float
exp_(const float val) {
src/infiniop/ops/gemm/moore/gemm_moore.h (new file)

Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+#ifndef __GEMM_MOORE_H__
+#define __GEMM_MOORE_H__
+
+#include "../gemm.h"
+
+DESCRIPTOR(moore)
+
+#endif // __GEMM_MOORE_H__
src/infiniop/ops/gemm/moore/gemm_moore.mu (new file)

Lines changed: 125 additions & 0 deletions

@@ -0,0 +1,125 @@
+#include "../../../devices/moore/moore_common.h"
+#include "../../../devices/moore/moore_handle.h"
+#include "gemm_moore.h"
+
+namespace op::gemm::moore {
+
+struct Descriptor::Opaque {
+    std::shared_ptr<device::moore::Handle::Internal> internal;
+};
+
+Descriptor::~Descriptor() {
+    delete _opaque;
+}
+
+infiniStatus_t Descriptor::create(
+    infiniopHandle_t handle_,
+    Descriptor **desc_ptr,
+    infiniopTensorDescriptor_t c_desc,
+    infiniopTensorDescriptor_t a_desc,
+    infiniopTensorDescriptor_t b_desc) {
+    auto handle = reinterpret_cast<device::moore::Handle *>(handle_);
+    auto dtype = c_desc->dtype();
+
+    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
+
+    auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
+    CHECK_RESULT(result);
+
+    *desc_ptr = new Descriptor(
+        dtype, result.take(), 0,
+        new Opaque{handle->internal()},
+        handle->device, handle->device_id);
+    return INFINI_STATUS_SUCCESS;
+}
+
+infiniStatus_t Descriptor::calculate(
+    void *workspace,
+    size_t workspace_size,
+    void *c,
+    float beta,
+    const void *a,
+    const void *b,
+    float alpha,
+    void *stream) const {
+
+    musaDataType a_type, b_type, c_type;
+    mublasComputeType_t compute_type;
+
+    // MUSA's GEMM operations require that the scalar values alpha and beta have the same data type as the matrices.
+    // This ensures correct computation during the muBLAS GEMM operation.
+    // Declare half-precision variables to handle F16 types.
+    half alpha_h, beta_h;
+
+    // Initialize generic void pointers for alpha and beta.
+    // They point to the original float values
+    // It will be used directly when the GEMM operation is performed with F32 data.
+    const void *p_alpha = &alpha;
+    const void *p_beta = &beta;
+
+    switch (_dtype) {
+    case INFINI_DTYPE_F16:
+        a_type = b_type = c_type = MUSA_R_16F;
+        compute_type = MUBLAS_COMPUTE_16F;
+
+        // Convert alpha/beta to half-precision and update the pointers.
+        alpha_h = __float2half(alpha);
+        beta_h = __float2half(beta);
+        p_alpha = &alpha_h;
+        p_beta = &beta_h;
+
+        break;
+    case INFINI_DTYPE_BF16:
+        a_type = b_type = c_type = MUSA_R_16BF;
+        compute_type = MUBLAS_COMPUTE_32F;
+        break;
+    case INFINI_DTYPE_F32:
+        a_type = b_type = c_type = MUSA_R_32F;
+        compute_type = MUBLAS_COMPUTE_32F_FAST_TF32;
+        break;
+
+    default:
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+
+    if (_info.is_transed) {
+        std::swap(a, b);
+    }
+
+    auto op_a = _info.a_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T;
+    auto op_b = _info.b_matrix.row_stride == 1 ? MUBLAS_OP_N : MUBLAS_OP_T;
+
+    CHECK_STATUS(_opaque->internal->useMublas(
+        (musaStream_t)stream,
+        [&](mublasHandle_t handle) {
+            CHECK_MUBLAS(
+                mublasGemmStridedBatchedEx(
+                    handle,
+                    op_a,
+                    op_b,
+                    static_cast<int>(_info.m),
+                    static_cast<int>(_info.n),
+                    static_cast<int>(_info.k),
+                    p_alpha,
+                    a,
+                    a_type,
+                    static_cast<int>(_info.a_matrix.ld()),
+                    _info.a_matrix.stride,
+                    b,
+                    b_type,
+                    static_cast<int>(_info.b_matrix.ld()),
+                    _info.b_matrix.stride,
+                    p_beta,
+                    c,
+                    c_type,
+                    static_cast<int>(_info.c_matrix.ld()),
+                    _info.c_matrix.stride,
+                    static_cast<int>(_info.batch),
+                    compute_type,
+                    MUBLAS_GEMM_DEFAULT));
+            return INFINI_STATUS_SUCCESS;
+        }));
+    return INFINI_STATUS_SUCCESS;
+}
+
+} // namespace op::gemm::moore
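
For reference, a caller-side sketch of the new path follows. The Descriptor::create and Descriptor::calculate signatures are taken from the diff above; the runBf16Gemm wrapper, the pre-built BF16 tensor descriptors, and the choice of alpha = 1, beta = 0 are hypothetical, added only for illustration.

// Hypothetical helper (illustration only): drives the new BF16 GEMM path.
// Assumes `handle` is a Moore-device handle and that BF16 tensor descriptors
// for c/a/b were created elsewhere.
infiniStatus_t runBf16Gemm(infiniopHandle_t handle,
                           infiniopTensorDescriptor_t c_desc,
                           infiniopTensorDescriptor_t a_desc,
                           infiniopTensorDescriptor_t b_desc,
                           void *c, const void *a, const void *b, void *stream) {
    op::gemm::moore::Descriptor *desc = nullptr;
    // BF16 now passes CHECK_DTYPE in Descriptor::create.
    CHECK_STATUS(op::gemm::moore::Descriptor::create(handle, &desc, c_desc, a_desc, b_desc));
    // alpha/beta are plain floats; the BF16 branch computes in FP32 internally.
    CHECK_STATUS(desc->calculate(/*workspace=*/nullptr, /*workspace_size=*/0,
                                 c, /*beta=*/0.0f, a, b, /*alpha=*/1.0f, stream));
    delete desc;
    return INFINI_STATUS_SUCCESS;
}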

src/infiniop/ops/gemm/musa/gemm_musa.h

Lines changed: 0 additions & 8 deletions
This file was deleted.

src/infiniop/ops/gemm/musa/gemm_musa.mu

Lines changed: 0 additions & 121 deletions
This file was deleted.
