Skip to content

Commit 15eb680

Browse files
committed
reduce more cpu overhead
Signed-off-by: daquexian <daquexian566@gmail.com>
1 parent aec55ba commit 15eb680

File tree

3 files changed

+66
-54
lines changed

3 files changed

+66
-54
lines changed

rwkv_pip_package/src/rwkv/cuda/att_one_v5.cu

Lines changed: 48 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,20 @@ struct Mix {
5454
__hmul(sx_, __hsub(__float2half(1), r_mix_)));
5555
}
5656
};
57+
58+
struct ToHalf {
59+
const float *x;
60+
half *y;
61+
__device__ void operator()(int i) const { y[i] = __float2half(x[i]); }
62+
};
63+
64+
struct InplaceAdd {
65+
__device__ __forceinline__ half operator()(int i) const {
66+
y[i] = __hadd(x[i], y[i]);
67+
}
68+
half *y;
69+
half *x;
70+
};
5771
} // namespace
5872

5973
using torch::Tensor;
@@ -64,50 +78,44 @@ void gemm_cublas(const void *a, const void *b, void *c, int batch, int ori_m,
6478
at::ScalarType torch_output_dtype);
6579

6680
Tensor att_one_v5(Tensor x, Tensor sx, Tensor s, Tensor ln_w, Tensor ln_b,
67-
Tensor lx_w, Tensor lx_b, Tensor kvr_mix,
68-
/* imm */ Tensor kvrx, Tensor kvrw, Tensor ow, Tensor t_first,
69-
Tensor t_decay, /* imm */ Tensor kvr, /* imm */ Tensor a,
70-
/* imm */ Tensor buf,
71-
/* imm */ Tensor s1,
72-
/* out */ Tensor x_plus_out, /* out */ Tensor s2) {
81+
Tensor lx_w, Tensor lx_b, Tensor kvr_mix, Tensor kvrw,
82+
Tensor ow, Tensor t_first, Tensor t_decay, Tensor tmp,
83+
Tensor buf, /* out */ Tensor s2_t,
84+
/* out */ Tensor x_plus_out_t) {
7385
const int x_numel = x.numel();
7486
Tensor xx = at::layer_norm(x, {x_numel}, ln_w, ln_b);
87+
int H = t_decay.size(0);
88+
int S = x_numel / H;
89+
char *buf_ptr = (char *)buf.data_ptr();
90+
half *kvrx = (half *)buf_ptr;
91+
float *kvr = (float *)(kvrx + 3 * x_numel);
92+
float *a = kvr + 3 * x_numel;
93+
half *tmp2 = (half *)(a + H * S * S);
94+
float *s1 = (float *)(tmp2 + x_numel);
95+
float *s2 = data_ptr<float>(s2_t);
96+
half *x_plus_out = data_ptr<half>(x_plus_out_t);
97+
7598
element_wise(Mix{data_ptr<half>(xx), data_ptr<half>(sx),
76-
data_ptr<half>(kvr_mix), static_cast<int>(x_numel),
77-
data_ptr<half>(kvrx)},
99+
data_ptr<half>(kvr_mix), static_cast<int>(x_numel), kvrx},
78100
x_numel);
79101

80-
int H = t_decay.size(0);
81-
int S = x_numel / H;
82-
// gemm_cublas_tensor(at::unsqueeze(kvrx, 1), kvrw, kvr);
83-
gemm_cublas(data_ptr<half>(kvrx), data_ptr<half>(kvrw), data_ptr<float>(kvr),
84-
3, 1, x_numel, x_numel, at::kHalf, at::kFloat);
85-
float* k = data_ptr<float>(kvr);
86-
float* v = k + x_numel;
87-
float* r = v + x_numel;
88-
// Tensor k = at::reshape(kvr[0], {H, S, 1});
89-
// Tensor v = at::reshape(kvr[1], {H, 1, S});
90-
// Tensor r = at::reshape(kvr[2], {H, 1, S});
91-
92-
// gemm_cublas_tensor(k, v, a);
93-
gemm_cublas(k, v, data_ptr<float>(a), H, S, S, 1, at::kFloat, at::kFloat);
94-
// s1 = t_first * a + s
95-
// s2 = a + t_decay * s
96-
element_wise(Fused1{data_ptr<float>(t_first), data_ptr<float>(t_decay),
97-
data_ptr<float>(a), data_ptr<float>(s),
98-
static_cast<int32_t>(a.size(1) * a.size(2)),
99-
data_ptr<float>(s1), data_ptr<float>(s2)},
100-
a.numel());
101-
102-
// gemm_cublas_tensor(r, s1, buf);
103-
gemm_cublas(r, data_ptr<float>(s1), data_ptr<float>(buf), H, 1, S, S,
104-
at::kFloat, at::kFloat);
105-
buf = at::group_norm(buf, H, lx_w, lx_b);
106-
buf = at::_cast_Half(buf);
107-
108-
// gemm_cublas_tensor(buf, ow, x_plus_out);
109-
gemm_cublas(data_ptr<half>(buf), data_ptr<half>(ow), data_ptr<half>(x_plus_out),
110-
1, 1, x_numel, x_numel, at::kHalf, at::kHalf);
111-
x_plus_out += x;
102+
gemm_cublas(kvrx, data_ptr<half>(kvrw), kvr, 3, 1, x_numel, x_numel,
103+
at::kHalf, at::kFloat);
104+
float *k = kvr;
105+
float *v = k + x_numel;
106+
float *r = v + x_numel;
107+
108+
gemm_cublas(k, v, a, H, S, S, 1, at::kFloat, at::kFloat);
109+
element_wise(Fused1{data_ptr<float>(t_first), data_ptr<float>(t_decay), a,
110+
data_ptr<float>(s), static_cast<int32_t>(S * S), s1, s2},
111+
H * S * S);
112+
113+
gemm_cublas(r, s1, data_ptr<float>(tmp), H, 1, S, S, at::kFloat, at::kFloat);
114+
tmp = at::group_norm(tmp, H, lx_w, lx_b);
115+
element_wise(ToHalf{data_ptr<float>(tmp), tmp2}, tmp.numel());
116+
117+
gemm_cublas(tmp2, data_ptr<half>(ow), x_plus_out, 1, 1, x_numel, x_numel,
118+
at::kHalf, at::kHalf);
119+
element_wise(InplaceAdd{x_plus_out, data_ptr<half>(x)}, x.numel());
112120
return xx;
113121
}

rwkv_pip_package/src/rwkv/cuda/wrapper.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -137,12 +137,10 @@ Tensor att_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
137137
Tensor t_decay, /* imm */ Tensor buf, /* out */ Tensor x_plus_out);
138138

139139
Tensor att_one_v5(Tensor x, Tensor sx, Tensor s, Tensor ln_w, Tensor ln_b,
140-
Tensor lx_w, Tensor lx_b, Tensor kvr_mix,
141-
/* imm */ Tensor kvrx, Tensor kvrw, Tensor ow, Tensor t_first,
142-
Tensor t_decay, /* imm */ Tensor kvr, /* imm */ Tensor a,
143-
/* imm */ Tensor buf,
144-
/* imm */ Tensor s1,
145-
/* out */ Tensor x_plus_out, /* out */ Tensor s2);
140+
Tensor lx_w, Tensor lx_b, Tensor kvr_mix, Tensor kvrw,
141+
Tensor ow, Tensor t_first, Tensor t_decay, Tensor tmp,
142+
Tensor buf, /* out */ Tensor s2_t,
143+
/* out */ Tensor x_plus_out_t);
146144

147145
Tensor ffn_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
148146
Tensor r_mix, Tensor kw, Tensor vw, Tensor rw,

rwkv_pip_package/src/rwkv/model.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -746,21 +746,27 @@ def cuda_att_one_fp16(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix,
746746

747747
@MyFunction
748748
def cuda_att_one_v5_fp16(self, x, sx, s, ln_w, ln_b, lx_w, lx_b, kvr_mix, t_decay, t_first, kvrw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory):
749-
kvrx = torch.empty((3, x.numel()), dtype=x.dtype, device=x.device)
750-
751749
H = t_decay.shape[0]
752750
S = x.shape[-1] // H
753-
754-
kvr = torch.empty((3, 1, x.shape[-1]), dtype=torch.float32, device=x.device)
755-
a = torch.empty((H, S, S), dtype=torch.float32, device=x.device)
756-
buf = torch.empty((1, x.shape[-1]), dtype=torch.float32, device=x.device)
757-
s1 = torch.empty((H, S, S), dtype=torch.float32, device=x.device)
751+
tmp = torch.empty((1, x.shape[-1]), dtype=torch.float32, device=x.device)
752+
buf = torch.empty((3 * x.numel() * 2 + 3 * x.numel() * 4 + H * S * S * 4 + x.numel() * 2 + H * S * S * 4,), dtype=torch.int8, device=x.device)
753+
# two outputs
758754
s2 = torch.empty((H, S, S), dtype=torch.float32, device=x.device)
759755
x_plus_out = torch.empty_like(x)
756+
757+
# kvrx = torch.empty((3, x.numel()), dtype=x.dtype, device=x.device)
758+
# kvr = torch.empty((3, 1, x.shape[-1]), dtype=torch.float32, device=x.device)
759+
# a = torch.empty((H, S, S), dtype=torch.float32, device=x.device)
760+
# tmp2 = torch.empty((1, x.shape[-1]), dtype=torch.float16, device=x.device)
761+
# s1 = torch.empty((H, S, S), dtype=torch.float32, device=x.device)
762+
# s2 = torch.empty((H, S, S), dtype=torch.float32, device=x.device)
763+
# x_plus_out = torch.empty_like(x)
764+
760765
# import pdb; pdb.set_trace()
761766

762-
xx = torch.ops.rwkv.att_one_v5(x, sx, s, ln_w, ln_b, lx_w, lx_b, kvr_mix, kvrx, kvrw, ow, t_first, t_decay, kvr, a, buf, s1, x_plus_out, s2) # type: ignore[reportGeneralTypeIssues]
767+
xx = torch.ops.rwkv.att_one_v5(x, sx, s, ln_w, ln_b, lx_w, lx_b, kvr_mix, kvrw, ow, t_first, t_decay, tmp, buf, s2, x_plus_out) # type: ignore[reportGeneralTypeIssues]
763768

769+
# import pdb; pdb.set_trace()
764770
return x_plus_out, xx, s2
765771

766772
@MyFunction

0 commit comments

Comments
 (0)