Commit 6120499

fast fp16 seq and one mode

Signed-off-by: daquexian <daquexian566@gmail.com>

1 parent 6b6965c commit 6120499

File tree: 8 files changed (+558 / -34 lines)

Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
#include "ATen/ATen.h"
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

#include "element_wise.h"
#include "util.h"

// Equivalent Python code:
// ww = t_first + k
// p = torch.maximum(pp, ww)
// e1 = torch.exp(pp - p)
// e2 = torch.exp(ww - p)
// wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype)
// ww = t_decay + pp
// p = torch.maximum(ww, k)
// e1 = torch.exp(ww - p)
// e2 = torch.exp(k - p)
// t1 = e1 * aa + e2 * v
// t2 = e1 * bb + e2
// r = r * wkv
// return t1, t2, p, r
struct WkvForwardOne {
  const float *t_first;
  const float *k;
  const float *pp;
  const float *aa;
  const float *bb;
  const float *t_decay;
  const float *v;
  /* out */ float *t1;
  /* out */ float *t2;
  /* out */ float *p;
  /* in & out */ half *r;

  __device__ void operator()(int i) const {
    float ww = t_first[i] + k[i];
    float pp_ = pp[i];
    float p_ = (pp_ > ww) ? pp_ : ww;
    float e1 = expf(pp_ - p_);
    float e2 = expf(ww - p_);
    float aa_ = aa[i];
    float bb_ = bb[i];
    float v_ = v[i];
    r[i] = __hmul(r[i], __float2half(((e1 * aa_ + e2 * v_) / (e1 * bb_ + e2))));
    ww = t_decay[i] + pp_;
    float k_ = k[i];
    p_ = (ww > k_) ? ww : k_;
    e1 = expf(ww - p_);
    e2 = expf(k_ - p_);
    t1[i] = e1 * aa_ + e2 * v_;
    t2[i] = e1 * bb_ + e2;
    p[i] = p_;
  }
};

/*
Equivalent Python code:
kx = xx * k_mix + sx * (1 - k_mix)
vx = xx * v_mix + sx * (1 - v_mix)
rx = xx * r_mix + sx * (1 - r_mix)
*/
struct Mix {
  const half *xx;
  const half *sx;
  const half *k_mix;
  const half *v_mix;
  const half *r_mix;
  /* out */ half *kx;
  /* out */ half *vx;
  /* out */ half *rx;

  __device__ void operator()(int i) const {
    half xx_ = xx[i];
    half sx_ = sx[i];
    half k_mix_ = k_mix[i];
    half v_mix_ = v_mix[i];
    half r_mix_ = r_mix[i];
    kx[i] = __hadd(__hmul(xx_, k_mix_),
                   __hmul(sx_, __hsub(__float2half(1), k_mix_)));
    vx[i] = __hadd(__hmul(xx_, v_mix_),
                   __hmul(sx_, __hsub(__float2half(1), v_mix_)));
    rx[i] = __hadd(__hmul(xx_, r_mix_),
                   __hmul(sx_, __hsub(__float2half(1), r_mix_)));
  }
};

using torch::Tensor;

void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);

Tensor att_one(Tensor x, Tensor ln_w, Tensor ln_b, Tensor sx, Tensor k_mix,
               Tensor v_mix, Tensor r_mix, Tensor kw,
               /* imm */ Tensor kx, Tensor vw, /* imm */ Tensor vx, Tensor rw,
               /* imm */ Tensor rx, Tensor ow, Tensor t_first,
               /* imm */ Tensor k, Tensor pp, Tensor ww, Tensor aa, Tensor bb,
               Tensor t_decay, /* imm */ Tensor v, /* in & out */ Tensor r,
               /* out */ Tensor x_plus_out, /* out */ Tensor t1,
               /* out */ Tensor t2, /* out */ Tensor p) {
  Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b);
  element_wise(Mix{data_ptr<half>(xx), data_ptr<half>(sx),
                   data_ptr<half>(k_mix), data_ptr<half>(v_mix),
                   data_ptr<half>(r_mix), data_ptr<half>(kx),
                   data_ptr<half>(vx), data_ptr<half>(rx)},
               x.numel());

  gemm_fp16_cublas(kx, kw, k);
  gemm_fp16_cublas(vx, vw, v);
  gemm_fp16_cublas(rx, rw, r);
  at::sigmoid_(r);

  element_wise(WkvForwardOne{data_ptr<float>(t_first), data_ptr<float>(k),
                             data_ptr<float>(pp), data_ptr<float>(aa),
                             data_ptr<float>(bb), data_ptr<float>(t_decay),
                             data_ptr<float>(v), data_ptr<float>(t1),
                             data_ptr<float>(t2), data_ptr<float>(p),
                             data_ptr<half>(r)},
               x.numel());

  gemm_fp16_cublas(r, ow, x_plus_out);
  x_plus_out += x;
  return xx;
}
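
The Python-side binding for att_one is not part of this hunk. A minimal sketch of how the entry point could be exposed through a pybind11 torch extension; the module registration and docstring below are assumptions, not code from this commit:

// Hypothetical binding sketch, not part of this commit; assumes att_one is
// declared as above and <torch/extension.h> is already included.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("att_one", &att_one, "fast fp16 attention, single-token (one) mode");
}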
Lines changed: 180 additions & 0 deletions
@@ -0,0 +1,180 @@
#include "ATen/ATen.h"
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

#include "util.h"
#include "element_wise.h"

using torch::Tensor;

void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
void gemm_fp16_cublas(const void *a, const void *b, void *c, int m,
                      int n, int k, bool output_fp32);

// One thread per (batch, channel): scans the T time steps sequentially,
// updating the running (aa, bb, pp) WKV state.
__global__ void kernel_wkv_forward_new(
    const int B, const int T, const int C, const float *__restrict__ const _w,
    const float *__restrict__ const _u, const float *__restrict__ const _k,
    const float *__restrict__ const _v, const half *__restrict__ const r,
    half *__restrict__ const _y, float *__restrict__ const _aa,
    float *__restrict__ const _bb, float *__restrict__ const _pp) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int _b = idx / C;
  const int _c = idx % C;
  const int _offset = _b * T * C + _c;
  const int _state_offset = _b * C + _c;

  float u = _u[_c];
  float w = _w[_c];
  const float *__restrict__ const k = _k + _offset;
  const float *__restrict__ const v = _v + _offset;
  half *__restrict__ const y = _y + _offset;

  float aa = _aa[_state_offset];
  float bb = _bb[_state_offset];
  float pp = _pp[_state_offset];
  for (int i = 0; i < T; i++) {
    const int ii = i * C;
    const float kk = k[ii];
    const float vv = v[ii];
    float ww = u + kk;
    float p = max(pp, ww);
    float e1 = exp(pp - p);
    float e2 = exp(ww - p);
    // y[ii] = __hmul(__float2half((e1 * aa + e2 * vv) / (e1 * bb + e2)),
    //                r[ii]);
    y[ii] = __float2half((e1 * aa + e2 * vv) / (e1 * bb + e2));
    ww = w + pp;
    p = max(ww, kk);
    e1 = exp(ww - p);
    e2 = exp(kk - p);
    aa = e1 * aa + e2 * vv;
    bb = e1 * bb + e2;
    pp = p;
  }
  _aa[_state_offset] = aa;
  _bb[_state_offset] = bb;
  _pp[_state_offset] = pp;
}

void cuda_wkv_forward_new(int B, int T, int C, float *w, float *u, float *k,
                          float *v, half *r, half *y, float *aa, float *bb,
                          float *pp) {
  dim3 threadsPerBlock(min(C, 32));
  assert(B * C % threadsPerBlock.x == 0);
  dim3 numBlocks(B * C / threadsPerBlock.x);
  kernel_wkv_forward_new<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, r,
                                                         y, aa, bb, pp);
}

__global__ void _att_mix(const half *xx, const half *sx, const half *k_mix,
                         const half *v_mix, const half *r_mix,
                         const int outer_size, const int inner_size, half *kx,
                         half *vx, half *rx) {
  for (int idx2 = blockIdx.x * blockDim.x + threadIdx.x; idx2 < inner_size;
       idx2 += blockDim.x * gridDim.x) {
    half k_mix_ = k_mix[idx2];
    half v_mix_ = v_mix[idx2];
    half r_mix_ = r_mix[idx2];
    for (int row = 0; row < outer_size; ++row) {
      int idx1 = row * inner_size + idx2;
      half xx_ = xx[idx1];
      half sx_ = sx[idx1];
      kx[idx1] = __hadd(__hmul(xx_, k_mix_),
                        __hmul(sx_, __hsub(__float2half(1), k_mix_)));
      vx[idx1] = __hadd(__hmul(xx_, v_mix_),
                        __hmul(sx_, __hsub(__float2half(1), v_mix_)));
      rx[idx1] = __hadd(__hmul(xx_, r_mix_),
                        __hmul(sx_, __hsub(__float2half(1), r_mix_)));
    }
  }
}

void att_mix(const half *xx, const half *sx, const half *k_mix,
             const half *v_mix, const half *r_mix, const int outer_size,
             const int inner_size, half *kx, half *vx, half *rx) {
  // 256 is good enough on most GPUs
  const int32_t BLOCK_SIZE = 256;
  assert(inner_size % BLOCK_SIZE == 0);
  _att_mix<<<inner_size / BLOCK_SIZE, BLOCK_SIZE>>>(
      xx, sx, k_mix, v_mix, r_mix, outer_size, inner_size, kx, vx, rx);
}

// In-place sigmoid over an fp16 buffer, computed in fp32 and stored as fp16.
struct InplaceSigmoid {
  __device__ __forceinline__ void operator()(int i) const {
    ptr[i] = __float2half(1.0 / (1.0 + exp(-__half2float(ptr[i]))));
  }
  half *ptr;
};

// In-place element-wise product: y = x * y, in fp16.
struct InplaceMul {
  __device__ __forceinline__ void operator()(int i) const {
    y[i] = __hmul(x[i], y[i]);
  }
  half *y;
  half *x;
};

/*
Equivalent Python code:

xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
sx = torch.cat((sx.unsqueeze(0), xx[:-1,:]))
kx = xx * k_mix + sx * (1 - k_mix)
vx = xx * v_mix + sx * (1 - v_mix)
rx = xx * r_mix + sx * (1 - r_mix)

r = torch.sigmoid(gemm(rx, rw))
k = gemm(kx, kw, output_dtype=torch.float32)
v = gemm(vx, vw, output_dtype=torch.float32)

T = x.shape[0]
for t in range(T):
    kk = k[t]
    vv = v[t]
    ww = t_first + kk
    p = torch.maximum(pp, ww)
    e1 = torch.exp(pp - p)
    e2 = torch.exp(ww - p)
    sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype)
    ww = t_decay + pp
    p = torch.maximum(ww, kk)
    e1 = torch.exp(ww - p)
    e2 = torch.exp(kk - p)
    aa = e1 * aa + e2 * vv
    bb = e1 * bb + e2
    pp = p
out = gemm(r * sx, ow)
return x + out, xx[-1,:], aa, bb, pp
*/
Tensor att_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
               Tensor v_mix, Tensor r_mix, Tensor kw, Tensor vw, Tensor rw,
               Tensor ow, Tensor t_first, Tensor pp, Tensor aa, Tensor bb,
               Tensor t_decay, /* imm */ Tensor buf,
               /* out */ Tensor x_plus_out) {
  Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b);
  sx = at::cat({sx.unsqueeze(0), xx.slice(0, 0, -1)}, 0);
  char *buf_ptr = (char *)buf.data_ptr();
  half *kx = (half *)buf_ptr;
  half *vx = kx + x.numel();
  half *rx = vx + x.numel();
  half *wkv_y = rx + x.numel();
  att_mix(data_ptr<half>(xx), data_ptr<half>(sx), data_ptr<half>(k_mix),
          data_ptr<half>(v_mix), data_ptr<half>(r_mix), xx.size(0), xx.size(1),
          kx, vx, rx);
  float *k = reinterpret_cast<float *>(wkv_y + x.numel());
  float *v = k + x.size(0) * kw.size(1);
  half *r = reinterpret_cast<half *>(v + x.size(0) * vw.size(1));

  gemm_fp16_cublas(kx, kw.data_ptr(), k, x.size(0), kw.size(1), kw.size(0), true);
  gemm_fp16_cublas(vx, vw.data_ptr(), v, x.size(0), vw.size(1), vw.size(0), true);
  gemm_fp16_cublas(rx, rw.data_ptr(), r, x.size(0), rw.size(1), rw.size(0), false);
  element_wise(InplaceSigmoid{r}, x.size(0) * rw.size(1));
  cuda_wkv_forward_new(1, x.size(0), x.size(1), data_ptr<float>(t_decay),
                       data_ptr<float>(t_first), k, v, r, wkv_y,
                       data_ptr<float>(aa), data_ptr<float>(bb),
                       data_ptr<float>(pp));
  element_wise(InplaceMul{wkv_y, r}, x.numel());
  gemm_fp16_cublas(wkv_y, ow.data_ptr(), x_plus_out.data_ptr(), x.size(0),
                   ow.size(1), ow.size(0), false);
  x_plus_out += x;
  return xx;
}
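
att_seq carves every intermediate out of the single buf tensor in the order kx, vx, rx, wkv_y (fp16), then k and v (fp32), then r (fp16). As a sketch of how a caller might size that scratch buffer, following the pointer arithmetic above (the helper name and signature are assumptions, not part of this commit):

// Hypothetical helper mirroring the buf layout inside att_seq.
// T = x.size(0), C = x.size(1); k_out/v_out/r_out are kw/vw/rw.size(1).
int64_t att_seq_buf_bytes(int64_t T, int64_t C,
                          int64_t k_out, int64_t v_out, int64_t r_out) {
  int64_t bytes = 4 * T * C * sizeof(half);  // kx, vx, rx, wkv_y (fp16)
  bytes += T * k_out * sizeof(float);        // k (fp32 GEMM output)
  bytes += T * v_out * sizeof(float);        // v (fp32 GEMM output)
  bytes += T * r_out * sizeof(half);         // r (fp16 GEMM output)
  return bytes;
}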
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
#include <cassert>
#include <cstddef>
#include <cstdint>

template <typename Func> __global__ void _element_wise(Func func, int n) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    func(i);
  }
}

// NOTE: a packed data type (e.g. float4) is overkill for the current sizes
// (4096 in the 7B model and 768 in the 0.1B model),
// and is not faster than the plain float version.
template <typename Func> void element_wise(Func func, int n) {
  // 256 is good enough on most GPUs
  const int32_t BLOCK_SIZE = 256;
  assert(n % BLOCK_SIZE == 0);
  _element_wise<<<n / BLOCK_SIZE, BLOCK_SIZE>>>(func, n);
}
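
Any functor with a __device__ operator()(int) can be launched through element_wise, which is how att_one and att_seq dispatch Mix, WkvForwardOne, InplaceSigmoid, and InplaceMul above. For illustration, a hypothetical in-place fp16 scaling functor (not part of this commit):

// Hypothetical example functor; requires <cuda_fp16.h> for half and __hmul.
struct InplaceScale {
  __device__ __forceinline__ void operator()(int i) const {
    ptr[i] = __hmul(ptr[i], factor);
  }
  half *ptr;
  half factor;
};

// Usage (n must be a multiple of 256 to satisfy the assert in element_wise):
// element_wise(InplaceScale{ptr, __float2half(0.5f)}, n);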
