
Commit e20c000

Merge pull request #386 from InfiniTensor/issue-385
Issue/385: refactor RMSNorm on P800, adding multi-precision support
2 parents c24a52e + a2a463d · commit e20c000

File tree: 7 files changed, +305 −239 lines


src/infiniop/devices/kunlun/kunlun_kernel_common.h

Lines changed: 83 additions & 11 deletions
@@ -4,17 +4,25 @@
 // This header file will only be include by .xpu file
 #include "xpu/runtime.h"
 #include <xpu/kernel/xtdk.h>
+#include <xpu/kernel/xtdk_atomic_sm_xpu3.h>
 #include <xpu/kernel/xtdk_bf16.h>
 #include <xpu/kernel/xtdk_math.h>
 #include <xpu/kernel/xtdk_simd.h>
+#include <xpu/kernel/xtdk_trigonometric.h>

 namespace device::kunlun::kernel {

+#define SM_SIZE 10240
+
+/**
+ * @brief Define ptrdiff_t and size_t for kunlun xpu
+ * ptrdiff_t is 32 bit, size_t is 32 bit in xpu kernel
+ * We padding it into 64 bit for convience of DATACOPY
+ */
 typedef struct _ptrdiff_t {
     int32_t value;   // 32 bit
     int32_t padding; // 32 bit
 } _ptrdiff_t;
-
 // same as ptrdiff
 typedef struct _size_t {
     uint32_t value;
@@ -29,17 +37,83 @@ inline __device__ float lowerBitMask(int i) {
     return (1 << (i + 1)) - 1;
 }

-// Atomic add for reduce
-inline __device__ void atomicAddF32(__shared_ptr__ float *ptr, float value) {
-    int success = 1;
-    while (success) {
-        // SM2REG read 32bit data to register
-        float a = SM2REG_atomic(ptr);
-        a = a + value;
-        success = REG2SM_atomic(ptr, a);
+/**
+ * @brief Load data from shared memory
+ * @param p: pointer to shared memory
+ * @return loaded value
+ */
+template <typename T>
+__device__ inline T loadsm(__shared_ptr__ const T *p) {
+    T v;
+    if constexpr (std::is_same<T, half>::value
+                  || std::is_same<T, bfloat16_t>::value) {
+        __builtin_memcpy(&v, p, sizeof(T));
+    } else {
+        v = *p;
+    }
+    return v;
+}
+// Load len data from shared memory
+template <typename T>
+__device__ inline void loadsm(__shared_ptr__ const T *p, T *v, int len) {
+    __builtin_memcpy(v, p, len * sizeof(T));
+}
+
+/**
+ * @brief Convert data type. All data is in local memory
+ * @param v: input value
+ * @return output value
+ */
+template <typename Tout, typename Tin>
+__device__ inline Tout to(Tin v) {
+    if constexpr (std::is_same<Tin, half>::value) {
+        return __half2float(v);
+    } else if constexpr (std::is_same<Tin, bfloat16_t>::value) {
+        return __bfloat162float(v);
+    } else {
+        return static_cast<Tout>(v);
     }
 }

+/**
+ * @brief atomicAdd for kunlun xpu
+ * @param ptr: pointer to shared memory
+ * @param value: value to add
+ */
+template <typename T>
+inline __device__ T atomicAdd(__shared_ptr__ T *ptr, T value) {
+    T x = atomicadd(ptr, value);
+    return x;
+}
+// Specialize atomicAdd for half
+template <>
+inline __device__ half atomicAdd<half>(__shared_ptr__ half *ptr, half value) {
+    ticket_lock_mix();
+    __half old = loadsm(ptr);
+    float of = __half2float(old);
+    float vf = __half2float(value);
+    float sumf = of + vf;
+    half sum = __float2half_rn(sumf);
+    *ptr = sum;
+    mfence_sm();
+    ticket_unlock_mix();
+    return old;
+}
+// Specialize atomicAdd for bfloat16_t
+template <>
+inline __device__ bfloat16_t atomicAdd<bfloat16_t>(__shared_ptr__ bfloat16_t *ptr, bfloat16_t value) {
+    ticket_lock_mix();
+    bfloat16_t old = loadsm(ptr);
+    float of = __bfloat162float(old);
+    float vf = __bfloat162float(value);
+    float sumf = of + vf;
+    bfloat16_t sum = __float2bfloat16_rn(sumf);
+    *ptr = sum;
+    mfence_sm();
+    ticket_unlock_mix();
+    return old;
+}
+
 /**
  * @brief Get index of broadcasted input
  * flat_index: flatten index of output tensor
@@ -85,5 +159,3 @@ inline __device__ int indexToOffset(
 } // namespace device::kunlun::kernel

 #endif // __INFINIOP_KUNLUN_KERNEL_COMMON_H__
-// TODO: atomicAddF16
-// TODO: atomicAddI8
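
Taken together, the new helpers are the building blocks for multi-precision reductions: loadsm gives a safe per-element load of 16-bit types from shared memory, to<Tcompute> widens half/bfloat16 to float, and the templated atomicAdd presents one entry point for every element type (the half/bfloat16 specializations serialize through the ticket lock). A minimal sketch, not part of this commit, of how they might compose into a block-wide sum of squares; blockSumSquared is a hypothetical name, and acc is assumed to point to shared memory that core 0 has zero-initialized before the call:

// Hypothetical sketch only, not taken from this diff.
template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute>
__device__ Tcompute blockSumSquared(__shared_ptr__ const Tdata *x, int len,
                                    __shared_ptr__ Tcompute *acc) {
    Tcompute local = Tcompute(0);
    for (int i = core_id(); i < len; i += BLOCK_SIZE) {
        Tcompute xi = to<Tcompute>(loadsm(x + i)); // safe 16-bit load, widened once
        local += xi * xi;
    }
    // Generic path uses the atomicadd intrinsic; the half/bfloat16
    // specializations go through the ticket lock instead.
    atomicAdd(acc, local);
    sync_cluster(); // make every core's contribution visible
    return loadsm(acc);
}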
New file — Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+#ifndef __RMS_NORM_KUNLUN_KERNEL_H__
+#define __RMS_NORM_KUNLUN_KERNEL_H__
+
+#include "../../../devices/kunlun/kunlun_kernel_common.h"
+#include "../../../reduce/kunlun/reduce_kunlun.h"
+
+using namespace device::kunlun::kernel;
+
+template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
+__device__ void rmsnormBlock(
+    __shared_ptr__ Tdata *y,
+    __shared_ptr__ const Tdata *x,
+    __shared_ptr__ const Tweight *w,
+    size_t dim,
+    float epsilon) {
+
+    // Block reduce sum of x^2
+    Tcompute ss = op::common_kunlun::reduce_op::sumSquared<BLOCK_SIZE, Tdata, Tcompute>(x, dim);
+
+    __shared__ Tcompute rms;
+    if (core_id() == 0) {
+        rms = Tcompute(rsqrt(ss / Tcompute(dim) + epsilon));
+    }
+    sync_cluster();
+
+    // Copy contiguous x, w into local mem (load from shared memory safely)
+    for (size_t i = core_id(); i < dim; i += BLOCK_SIZE) {
+        Tdata xi = loadsm(x + i);
+        Tweight wi = loadsm(w + i);
+        y[i] = static_cast<Tdata>(to<Tcompute>(xi) * to<Tcompute>(wi) * rms);
+    }
+    sync_cluster();
+}
+
+#endif
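
For reference, rmsnormBlock implements the standard RMSNorm per row of length d = dim, with the reduction and the scaling carried out in Tcompute (e.g. float) before the result is cast back to Tdata:

y_i = w_i \cdot x_i \cdot \left( \frac{1}{d} \sum_{j=1}^{d} x_j^2 + \epsilon \right)^{-1/2}

The rsqrt call above is exactly this inverse square root; epsilon keeps the expression well defined for all-zero rows.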

src/infiniop/ops/rms_norm/kunlun/rms_norm_kernel.xpu

Lines changed: 0 additions & 125 deletions
This file was deleted.

src/infiniop/ops/rms_norm/kunlun/rms_norm_kunlun.cc

Lines changed: 0 additions & 79 deletions
This file was deleted.

src/infiniop/ops/rms_norm/kunlun/rms_norm_kunlun.h

Lines changed: 10 additions & 0 deletions
@@ -5,4 +5,14 @@

 DESCRIPTOR(kunlun)

+#define INSTANTIATE_RMSNORM_KERNEL(BLOCK_SIZE, Tcompute, Tdata, Tweight)          \
+    template __global__ void rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight>( \
+        Tdata * y,                                                                \
+        int32_t stride_y,                                                         \
+        const Tdata *x,                                                           \
+        int32_t stride_x,                                                         \
+        const Tweight *w,                                                         \
+        uint32_t dim,                                                             \
+        float epsilon);
+
 #endif
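
The macro above pins down the kernel interface and lets the implementation file explicitly instantiate rmsnormKernel once per supported (Tcompute, Tdata, Tweight) combination. As a hedged illustration only — the actual block size and type list live in the repository's .xpu file, which is not shown in this excerpt — the instantiation site might read:

// Hypothetical usage sketch; BLOCK_SIZE 64 and the exact type
// combinations are assumptions, not taken from this diff.
INSTANTIATE_RMSNORM_KERNEL(64, float, float, float)
INSTANTIATE_RMSNORM_KERNEL(64, float, half, half)
INSTANTIATE_RMSNORM_KERNEL(64, float, bfloat16_t, bfloat16_t)
INSTANTIATE_RMSNORM_KERNEL(64, float, half, float)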
