Skip to content

Commit 19f1322

Browse files
committed
摩尔添加:softmax(含bf16)
1 parent f6198d6 commit 19f1322

File tree

5 files changed

+196
-1
lines changed

5 files changed

+196
-1
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef __SOFTMAX_MOORE_H__
#define __SOFTMAX_MOORE_H__

#include "../softmax.h"

// Declares the Moore (MUSA) backend softmax Descriptor via the shared
// DESCRIPTOR macro from ../softmax.h; the implementation lives in
// softmax_moore.mu.
DESCRIPTOR(moore)

#endif
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
#include "../../../devices/moore/moore_common.h"
2+
#include "softmax_moore.h"
3+
4+
#include <cub/block/block_reduce.cuh>
5+
#include "../../../devices/moore/moore_kernel_common.h"
6+
7+
#include "../../../reduce/cuda/reduce.cuh"
8+
9+
#include "softmax_moore_kernel.h"
10+
11+
// Thin kernel entry point: INFINIOP_MOORE_KERNEL presumably expands to the
// device launch qualifiers (e.g. `__global__ void`) — confirm against
// moore_kernel_common.h. All work is delegated to softmaxKernel.
//
//   y / x     : output / input tensors (same shape, device memory)
//   othersize : number of softmax slices (product of all dims except the axis)
//   dimsize   : length of the softmax axis
//   stride    : element stride between consecutive entries along the axis
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute>
INFINIOP_MOORE_KERNEL softmax_kernel(
    Tdata *y, const Tdata *x,
    size_t othersize, size_t dimsize, ptrdiff_t stride) {
    softmaxKernel<BLOCK_SIZE, Tdata, Tcompute>(y, x, othersize, dimsize, stride);
}
17+
18+
namespace op::softmax::moore {

// Backend-private state: keeps the Moore device handle internals alive for
// the lifetime of the descriptor.
struct Descriptor::Opaque {
    std::shared_ptr<device::moore::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
    delete _opaque;
}

// Validates the tensor pair via SoftmaxInfo and builds a descriptor bound to
// the Moore handle. Workspace size is 0 — this softmax needs no scratch memory.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc,
    int axis) {
    auto info = SoftmaxInfo::create(y_desc, x_desc, axis);
    CHECK_RESULT(info);
    auto *moore_handle = reinterpret_cast<device::moore::Handle *>(handle);
    *desc_ptr = new Descriptor(
        new Opaque{moore_handle->internal()},
        info.take(), 0, handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}

// Dispatches on the runtime dtype and launches one thread block per softmax
// slice (grid.x == othersize). F16/BF16 accumulate in float.
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(void *y, const void *x, infiniDtype_t dtype,
                            size_t othersize, size_t dimsize, ptrdiff_t stride,
                            musaStream_t stream) {
    const dim3 grid(static_cast<uint32_t>(othersize), 1, 1);
    switch (dtype) {
    case INFINI_DTYPE_F16:
        softmax_kernel<BLOCK_SIZE, half, float>
            <<<grid, BLOCK_SIZE, 0, stream>>>(
                reinterpret_cast<half *>(y),
                reinterpret_cast<const half *>(x),
                othersize, dimsize, stride);
        break;
    case INFINI_DTYPE_BF16:
        softmax_kernel<BLOCK_SIZE, __mt_bfloat16, float>
            <<<grid, BLOCK_SIZE, 0, stream>>>(
                reinterpret_cast<__mt_bfloat16 *>(y),
                reinterpret_cast<const __mt_bfloat16 *>(x),
                othersize, dimsize, stride);
        break;
    case INFINI_DTYPE_F32:
        softmax_kernel<BLOCK_SIZE, float, float>
            <<<grid, BLOCK_SIZE, 0, stream>>>(
                reinterpret_cast<float *>(y),
                reinterpret_cast<const float *>(x),
                othersize, dimsize, stride);
        break;
    default:
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    return INFINI_STATUS_SUCCESS;
}

// Asynchronously enqueues the softmax on `stream_`. The block size is chosen
// from the device's max-threads-per-block capability; workspace is unused.
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
                                     void *y,
                                     const void *x,
                                     void *stream_) const {
    auto stream = reinterpret_cast<musaStream_t>(stream_);
    const auto max_threads = _opaque->internal->maxThreadsPerBlock();
    if (max_threads == MOORE_BLOCK_SIZE_1024) {
        CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_1024>(
            y, x, _info.dtype, _info.othersize, _info.dimsize, _info.stride, stream));
    } else if (max_threads == MOORE_BLOCK_SIZE_512) {
        CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_512>(
            y, x, _info.dtype, _info.othersize, _info.dimsize, _info.stride, stream));
    } else {
        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
    }
    return INFINI_STATUS_SUCCESS;
}

} // namespace op::softmax::moore
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
#ifndef __SOFTMAX_KERNEL_CUH__
#define __SOFTMAX_KERNEL_CUH__

// Numerically stable softmax along one axis, one thread block per slice.
//
// Launch contract (see the launcher in softmax_moore.mu):
//   - gridDim.x covers all slices; blockDim.x == BLOCK_SIZE.
//   - BLOCK_SIZE must be a power of two (the tree reductions below halve it).
//   - The exp/divide paths go through float (`expf`, float casts), so
//     Tcompute = float is the intended accumulation type.
//
// Three passes over the slice: (1) block-wide max, (2) y = exp(x - max) plus
// block-wide sum, (3) normalize y by the sum.
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute>
__device__ void softmaxKernel(
    Tdata *y_, const Tdata *x_,
    size_t othersize, // = outer_size * inner_size (number of softmax slices)
    size_t dimsize,   // = axis_size (length of the softmax axis)
    ptrdiff_t stride  // = inner_size (element stride along the softmax axis)
) {
    size_t other_idx = blockIdx.x;
    // Guard for grids larger than the slice count; whole block exits together,
    // so no barrier below is reached divergently.
    if (other_idx >= othersize) return;

    // -----------------------------------
    // Locate the base of this block's softmax slice: decompose the flat slice
    // index into (outer, inner) coordinates around the softmax axis.
    // -----------------------------------
    size_t inner_idx = other_idx % stride;
    size_t outer_idx = other_idx / stride;

    const Tdata *x = x_ + outer_idx * dimsize * stride + inner_idx;
    Tdata *y = y_ + outer_idx * dimsize * stride + inner_idx;

    // ---------------------------
    // 1. block max (subtracted later so exp() cannot overflow)
    // ---------------------------
    __shared__ Tcompute s_reduce[BLOCK_SIZE];
    __shared__ Tcompute s_max;

    Tcompute local_max = -INFINITY;

    // Block-stride loop: thread t handles elements t, t+BLOCK_SIZE, ...
    for (size_t i = threadIdx.x; i < dimsize; i += BLOCK_SIZE) {
        Tcompute v = static_cast<Tcompute>(x[i * stride]);
        local_max = v > local_max ? v : local_max;
    }

    // Threads with no elements (dimsize < BLOCK_SIZE) contribute -INFINITY,
    // the identity for max, so the reduction stays correct.
    s_reduce[threadIdx.x] = local_max;
    __syncthreads();

    // Tree reduction over shared memory; needs BLOCK_SIZE to be a power of two.
    for (unsigned int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) {
            s_reduce[threadIdx.x] =
                max(s_reduce[threadIdx.x], s_reduce[threadIdx.x + s]);
        }
        __syncthreads();
    }

    if (threadIdx.x == 0) s_max = s_reduce[0];
    // Publishes s_max and also makes it safe to reuse s_reduce for the sum.
    __syncthreads();

    // ---------------------------
    // 2. exp & sum: store exp(x - max) into y, accumulate per-thread partials
    // ---------------------------
    Tcompute local_sum = 0;

    for (size_t i = threadIdx.x; i < dimsize; i += BLOCK_SIZE) {
        Tcompute v =
            expf(static_cast<float>(x[i * stride]) - static_cast<float>(s_max));
        y[i * stride] = static_cast<Tdata>(v);
        local_sum += v;
    }

    s_reduce[threadIdx.x] = local_sum;
    __syncthreads();

    // Same power-of-two tree reduction, this time summing the partials.
    for (unsigned int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) {
            s_reduce[threadIdx.x] += s_reduce[threadIdx.x + s];
        }
        __syncthreads();
    }

    // Every thread reads the reduced total (valid after the final barrier).
    Tcompute sum = s_reduce[0];
    __syncthreads();

    // ---------------------------
    // 3. normalize: y /= sum (re-reads the exp values stored in pass 2)
    // ---------------------------
    for (size_t i = threadIdx.x; i < dimsize; i += BLOCK_SIZE) {
        y[i * stride] =
            static_cast<Tdata>(
                static_cast<float>(y[i * stride]) / static_cast<float>(sum));
    }
}

#endif // __SOFTMAX_KERNEL_CUH__

src/infiniop/ops/softmax/operator.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
66
#include "nvidia/softmax_nvidia.cuh"
77
#endif
8+
#ifdef ENABLE_MOORE_API
9+
#include "moore/softmax_moore.h"
10+
#endif
811

912
__C infiniStatus_t infiniopCreateSoftmaxDescriptor(
1013
infiniopHandle_t handle,
@@ -33,6 +36,9 @@ __C infiniStatus_t infiniopCreateSoftmaxDescriptor(
3336
#endif
3437
#ifdef ENABLE_HYGON_API
3538
CREATE(INFINI_DEVICE_HYGON, nvidia);
39+
#endif
40+
#ifdef ENABLE_MOORE_API
41+
CREATE(INFINI_DEVICE_MOORE, moore)
3642
#endif
3743
}
3844
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -57,6 +63,9 @@ __C infiniStatus_t infiniopGetSoftmaxWorkspaceSize(infiniopSoftmaxDescriptor_t d
5763
#endif
5864
#ifdef ENABLE_HYGON_API
5965
GET(INFINI_DEVICE_HYGON, nvidia);
66+
#endif
67+
#ifdef ENABLE_MOORE_API
68+
GET(INFINI_DEVICE_MOORE, moore)
6069
#endif
6170
}
6271
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -86,6 +95,9 @@ __C infiniStatus_t infiniopSoftmax(
8695
#endif
8796
#ifdef ENABLE_HYGON_API
8897
CALCULATE(INFINI_DEVICE_HYGON, nvidia);
98+
#endif
99+
#ifdef ENABLE_MOORE_API
100+
CALCULATE(INFINI_DEVICE_MOORE, moore)
89101
#endif
90102
}
91103
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -110,6 +122,9 @@ __C infiniStatus_t infiniopDestroySoftmaxDescriptor(infiniopSoftmaxDescriptor_t
110122
#endif
111123
#ifdef ENABLE_HYGON_API
112124
DESTROY(INFINI_DEVICE_HYGON, nvidia);
125+
#endif
126+
#ifdef ENABLE_MOORE_API
127+
DESTROY(INFINI_DEVICE_MOORE, moore)
113128
#endif
114129
}
115130
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;

test/infiniop/softmax.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,19 @@
3434
((1, 16, 512, 512), 1),
3535
((1, 16, 512, 512), 2),
3636
((1, 16, 512, 512), 3),
37+
((1, 32, 4096, 4096), 3), # GPT-3 / LLaMA attention
38+
((2, 16, 2048, 2048), 3),
39+
((4, 8, 1024, 1024), 3),
3740
]
3841

3942
# Data types used for testing
40-
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.F32]
43+
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.F32, InfiniDtype.BF16]
4144

4245
# Tolerance map for different data types
4346
_TOLERANCE_MAP = {
4447
InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2},
4548
InfiniDtype.F32: {"atol": 3e-5, "rtol": 1e-5},
49+
InfiniDtype.BF16: {"atol": 1e-2, "rtol": 1e-2},
4650
}
4751

4852

0 commit comments

Comments
 (0)