
Commit 2e6241e

Merge pull request InfiniTensor#2 from gofreelee/mooer/dev
layernorm for moore
2 parents 441e419 + 5e58311 commit 2e6241e

File tree

3 files changed: +249 -1 lines changed
src/infiniop/ops/layer_norm/moore/layer_norm_moore.h

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
#ifndef __LAYER_NORM_MOORE_H__
#define __LAYER_NORM_MOORE_H__

#include "../layer_norm.h"

DESCRIPTOR(moore)

#endif // __LAYER_NORM_MOORE_H__
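DESCRIPTOR(moore) stamps out the backend's Descriptor class inside namespace op::layer_norm::moore. The macro itself lives in ../layer_norm.h and is not part of this diff; the following is a rough, hypothetical sketch of the interface it presumably generates, inferred from the members and signatures the implementation below uses:

// Hypothetical sketch only -- the real DESCRIPTOR macro in ../layer_norm.h may differ.
namespace op::layer_norm::moore {
class Descriptor {
    struct Opaque;           // backend-specific state, defined by the implementation
    Opaque *_opaque;
    LayerNormInfo _info;     // shapes, strides, eps, dtype
    size_t _workspace_size;

public:
    ~Descriptor();
    static infiniStatus_t create(
        infiniopHandle_t handle,
        Descriptor **desc_ptr,
        infiniopTensorDescriptor_t output_desc,
        infiniopTensorDescriptor_t input_standardization_desc,
        infiniopTensorDescriptor_t input_std_deviation_desc,
        infiniopTensorDescriptor_t input_desc,
        infiniopTensorDescriptor_t weight_desc,
        infiniopTensorDescriptor_t bias_desc,
        float eps);
    infiniStatus_t calculate(
        void *workspace, size_t workspace_size,
        void *output,
        void *input_standardization,
        void *input_std_deviation,
        const void *input,
        const void *weight,
        const void *bias,
        void *stream) const;
};
} // namespace op::layer_norm::moore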
Lines changed: 224 additions & 0 deletions
@@ -0,0 +1,224 @@
#include "../../../devices/moore/moore_common.h"
#include "../../../devices/moore/moore_handle.h"
#include "../../../devices/moore/moore_kernel_common.h"

#include "../../../reduce/cuda/reduce.cuh"
#include "../info.h"
#include "layer_norm_moore.h"

#include <cub/block/block_reduce.cuh>

namespace op::layer_norm::moore {

struct Descriptor::Opaque {
    std::shared_ptr<device::moore::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
    delete _opaque;
}

template <unsigned int BLOCK_SIZE, typename T>
INFINIOP_MOORE_KERNEL layernormOutputKernel(
    T *__restrict__ output,
    const T *__restrict__ input,
    const T *__restrict__ weight,
    const T *__restrict__ bias,
    float eps,
    int dimsize,
    const ptrdiff_t *__restrict__ output_strides,
    const ptrdiff_t *__restrict__ input_strides,
    const size_t *__restrict__ shape,
    ptrdiff_t weight_stride,
    ptrdiff_t bias_stride,
    int ndim,
    bool bias_exist) {
    // Decompose blockIdx.x into per-dimension indices over the outer
    // (non-normalized) dimensions to find this row's base offsets.
    int ind_i = 0;
    int ind_o = 0;

    int tid = (int)blockIdx.x;
    for (int j = ndim - 2; j >= 0; j--) {
        int idx = tid % (int)shape[j];
        ind_i += idx * (int)input_strides[j];
        ind_o += idx * (int)output_strides[j];
        tid = tid / (int)shape[j];
    }

    // Block-wide mean of the row, accumulated in float32.
    float mu_partial = op::common_cuda::reduce_op::sum<BLOCK_SIZE, T, float>(
                           input + ind_i,
                           (size_t)dimsize)
                     / (float)dimsize;
    __shared__ float mu;
    if (threadIdx.x == 0) {
        mu = mu_partial;
    }
    __syncthreads();

    // Each thread accumulates its strided share of the squared deviations.
    float sigma2_partial = 0.0f;
    for (int id = (int)threadIdx.x; id < dimsize; id += (int)BLOCK_SIZE) {
        float v = static_cast<float>(input[ind_i + id]) - mu;
        sigma2_partial += v * v;
    }

    using BlockReduce = cub::BlockReduce<float, BLOCK_SIZE>;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    float sigma2_sum = BlockReduce(temp_storage).Sum(sigma2_partial);

    // Thread 0 finalizes 1 / sqrt(variance + eps) and broadcasts it.
    __shared__ float inv_std;
    if (threadIdx.x == 0) {
        float sigma_tmp = sqrtf(sigma2_sum * __fdividef(1.0F, (float)dimsize) + eps);
        inv_std = __fdividef(1.0F, sigma_tmp);
    }
    __syncthreads();

    // Normalize, scale by weight, and shift by bias (if present).
    for (int id = (int)threadIdx.x; id < dimsize; id += (int)BLOCK_SIZE) {
        float w = static_cast<float>(weight[id * weight_stride]);
        float b = bias_exist ? static_cast<float>(bias[id * bias_stride]) : 0.0f;
        float x = static_cast<float>(input[ind_i + id]);
        float y = w * (x - mu) * inv_std + b;
        output[ind_o + id] = static_cast<T>(y);
    }
}
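For reference, each block handles one length-dimsize row and computes the standard LayerNorm in float32 regardless of T:

    \mu = \frac{1}{d}\sum_{i=0}^{d-1} x_i, \qquad
    \sigma^2 = \frac{1}{d}\sum_{i=0}^{d-1}\left(x_i - \mu\right)^2, \qquad
    y_i = w_i \cdot \frac{x_i - \mu}{\sqrt{\sigma^2 + \varepsilon}} + b_i

with d = dimsize and b_i = 0 when bias_exist is false.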
template <unsigned int BLOCK_SIZE, typename T>
infiniStatus_t calculate_layer_norm(
    const LayerNormInfo &info,
    T *output,
    const T *input,
    const T *weight,
    const T *bias,
    musaStream_t stream,
    void *workspace) {
    size_t ndim = info.ndim;
    char *workspace_ptr = reinterpret_cast<char *>(workspace);

    // Carve the workspace into four ndim-length stride arrays followed by
    // the shape array.
    ptrdiff_t *input_strides_dev = reinterpret_cast<ptrdiff_t *>(workspace_ptr);
    ptrdiff_t *output_strides_dev = input_strides_dev + ndim;
    ptrdiff_t *input_standardization_strides_dev = output_strides_dev + ndim;
    ptrdiff_t *input_std_deviation_strides_dev = input_standardization_strides_dev + ndim;

    size_t ptrdiff_array_size = 4 * ndim * sizeof(ptrdiff_t);
    size_t *shape_dev = reinterpret_cast<size_t *>(workspace_ptr + ptrdiff_array_size);

    CHECK_MOORE(musaMemcpyAsync(input_strides_dev, info.input_strides.data(), sizeof(ptrdiff_t) * ndim, musaMemcpyHostToDevice, stream));
    CHECK_MOORE(musaMemcpyAsync(output_strides_dev, info.output_strides.data(), sizeof(ptrdiff_t) * ndim, musaMemcpyHostToDevice, stream));
    CHECK_MOORE(musaMemcpyAsync(input_standardization_strides_dev, info.input_standardization_strides.data(), sizeof(ptrdiff_t) * (ndim - 1), musaMemcpyHostToDevice, stream));
    CHECK_MOORE(musaMemcpyAsync(input_std_deviation_strides_dev, info.input_std_deviation_strides.data(), sizeof(ptrdiff_t) * (ndim - 1), musaMemcpyHostToDevice, stream));
    CHECK_MOORE(musaMemcpyAsync(shape_dev, info.input_shape.data(), sizeof(size_t) * ndim, musaMemcpyHostToDevice, stream));

    int dimsize = (int)info.normalized_size;
    int num_blocks = (int)info.othersize;

    // One block per row of the normalized dimension.
    layernormOutputKernel<BLOCK_SIZE, T>
        <<<num_blocks, BLOCK_SIZE, 0, stream>>>(
            output,
            input,
            weight,
            bias,
            info.eps,
            dimsize,
            output_strides_dev,
            input_strides_dev,
            shape_dev,
            info.weight_strides[0],
            info.bias_exist ? info.bias_strides[0] : 0,
            (int)info.ndim,
            info.bias_exist);

    return INFINI_STATUS_SUCCESS;
}
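The layout above must match the workspace reserved by Descriptor::create below; a quick check of the arithmetic:

    workspace_size = ndim * (4 * sizeof(ptrdiff_t) + sizeof(size_t))
                   = 3 * (4 * 8 + 8) = 120 bytes   (ndim = 3, LP64)

Note that the standardization and std-deviation stride arrays are uploaded but never passed to the kernel, which consumes only the input/output strides and the shape.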
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t output_desc,
    infiniopTensorDescriptor_t input_standardization_desc,
    infiniopTensorDescriptor_t input_std_deviation_desc,
    infiniopTensorDescriptor_t input_desc,
    infiniopTensorDescriptor_t weight_desc,
    infiniopTensorDescriptor_t bias_desc,
    float eps) {
    auto handle = reinterpret_cast<device::moore::Handle *>(handle_);

    auto dtype = output_desc->dtype();
    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);

    auto result = LayerNormInfo::createLayerNormInfo(
        output_desc,
        input_standardization_desc,
        input_std_deviation_desc,
        input_desc,
        weight_desc,
        bias_desc,
        eps);
    CHECK_RESULT(result);
    auto info = result.take();

    // Four ptrdiff_t stride arrays plus one size_t shape array, each of
    // length ndim (see calculate_layer_norm above).
    size_t workspace_size = output_desc->ndim() * (sizeof(ptrdiff_t) * 4 + sizeof(size_t));

    *desc_ptr = new Descriptor(
        dtype,
        std::move(info),
        workspace_size,
        new Opaque{handle->internal()},
        handle->device,
        handle->device_id);

    return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    void *input_standardization,
    void *input_std_deviation,
    const void *input,
    const void *weight,
    const void *bias,
    void *stream_) const {
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    // The fused kernel recomputes the mean and inverse std in-block, so the
    // intermediate standardization/std-deviation outputs are not produced.
    (void)input_standardization;
    (void)input_std_deviation;

    musaStream_t stream = (musaStream_t)stream_;

#define CALC(BLOCK_SIZE, TDATA) \
    calculate_layer_norm<BLOCK_SIZE, TDATA>(_info, (TDATA *)output, (const TDATA *)input, (const TDATA *)weight, (const TDATA *)bias, stream, workspace)

    // Some MUSA targets report maxThreadsPerBlock() == 2048, but a 2048-thread
    // BlockReduce can exceed the shared-memory limit. Clamp to 1024/512 for
    // compatibility.
    int max_threads = _opaque->internal->maxThreadsPerBlock();
    unsigned int block_size = (max_threads >= (int)MOORE_BLOCK_SIZE_1024) ? MOORE_BLOCK_SIZE_1024 : MOORE_BLOCK_SIZE_512;

    if (block_size == MOORE_BLOCK_SIZE_1024) {
        if (_info.dtype == INFINI_DTYPE_F16) {
            return CALC(MOORE_BLOCK_SIZE_1024, half);
        } else if (_info.dtype == INFINI_DTYPE_F32) {
            return CALC(MOORE_BLOCK_SIZE_1024, float);
        } else if (_info.dtype == INFINI_DTYPE_BF16) {
            return CALC(MOORE_BLOCK_SIZE_1024, __mt_bfloat16);
        } else {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
    } else if (block_size == MOORE_BLOCK_SIZE_512) {
        if (_info.dtype == INFINI_DTYPE_F16) {
            return CALC(MOORE_BLOCK_SIZE_512, half);
        } else if (_info.dtype == INFINI_DTYPE_F32) {
            return CALC(MOORE_BLOCK_SIZE_512, float);
        } else if (_info.dtype == INFINI_DTYPE_BF16) {
            return CALC(MOORE_BLOCK_SIZE_512, __mt_bfloat16);
        } else {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
    } else {
        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
    }

#undef CALC
}

} // namespace op::layer_norm::moore
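A minimal host-side sketch of the descriptor lifecycle, matching the create/calculate signatures above. The handle, tensor descriptors, device buffers, and stream are assumed to exist already, and workspaceSize() is a hypothetical accessor (whatever DESCRIPTOR actually generates may differ):

// Hedged sketch; CHECK is a placeholder for the caller's error handling.
op::layer_norm::moore::Descriptor *desc = nullptr;
CHECK(op::layer_norm::moore::Descriptor::create(
    handle, &desc,
    out_desc, standardization_desc, std_deviation_desc,
    in_desc, weight_desc, bias_desc, /*eps=*/1e-5f));

void *d_workspace = nullptr;
CHECK_MOORE(musaMalloc(&d_workspace, desc->workspaceSize())); // hypothetical accessor

CHECK(desc->calculate(d_workspace, desc->workspaceSize(), d_out,
                      /*input_standardization=*/nullptr, // ignored by this backend
                      /*input_std_deviation=*/nullptr,   // ignored by this backend
                      d_in, d_weight, d_bias, stream));

CHECK_MOORE(musaFree(d_workspace));
delete desc;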

src/infiniop/ops/layer_norm/operator.cc

Lines changed: 16 additions & 1 deletion

@@ -12,6 +12,9 @@
 #ifdef ENABLE_METAX_API
 #include "metax/layer_norm_metax.h"
 #endif
+#ifdef ENABLE_MOORE_API
+#include "moore/layer_norm_moore.h"
+#endif
 
 __C infiniStatus_t infiniopCreateLayerNormDescriptor(
     infiniopHandle_t handle,
@@ -53,6 +56,9 @@ __C infiniStatus_t infiniopCreateLayerNormDescriptor(
 #ifdef ENABLE_METAX_API
     CREATE(INFINI_DEVICE_METAX, metax);
 #endif
+#ifdef ENABLE_MOORE_API
+    CREATE(INFINI_DEVICE_MOORE, moore);
+#endif
 
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -82,6 +88,9 @@ __C infiniStatus_t infiniopGetLayerNormWorkspaceSize(infiniopLayerNormDescriptor
 #endif
 #ifdef ENABLE_METAX_API
     GET(INFINI_DEVICE_METAX, metax);
+#endif
+#ifdef ENABLE_MOORE_API
+    GET(INFINI_DEVICE_MOORE, moore);
 #endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -133,6 +142,9 @@ __C infiniStatus_t infiniopLayerNorm(
 #ifdef ENABLE_METAX_API
     CALCULATE(INFINI_DEVICE_METAX, metax);
 #endif
+#ifdef ENABLE_MOORE_API
+    CALCULATE(INFINI_DEVICE_MOORE, moore);
+#endif
 
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -163,10 +175,13 @@ infiniopDestroyLayerNormDescriptor(infiniopLayerNormDescriptor_t desc) {
 #ifdef ENABLE_METAX_API
     DELETE(INFINI_DEVICE_METAX, metax);
 #endif
+#ifdef ENABLE_MOORE_API
+    DELETE(INFINI_DEVICE_MOORE, moore);
+#endif
 
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
 
 #undef DELETE
-}
+}
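Each dispatch macro (CREATE, GET, CALCULATE, DELETE) maps a device enum to the matching backend namespace inside a switch on the handle's device type. The macros are defined earlier in operator.cc and are not part of this diff; a hypothetical expansion of CREATE(INFINI_DEVICE_MOORE, moore), following the visible pattern:

// Hypothetical expansion; the real CREATE macro in operator.cc may differ.
case INFINI_DEVICE_MOORE:
    return op::layer_norm::moore::Descriptor::create(
        handle,
        reinterpret_cast<op::layer_norm::moore::Descriptor **>(desc_ptr),
        output_desc, input_standardization_desc, input_std_deviation_desc,
        input_desc, weight_desc, bias_desc, eps);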
