|
| 1 | +#include "ATen/ATen.h" |
| 2 | +#include <cuda_fp16.h> |
| 3 | +#include <cuda_runtime.h> |
| 4 | +#include <torch/extension.h> |
| 5 | + |
| 6 | +#include "util.h" |
| 7 | +#include "element_wise.h" |
| 8 | + |
| 9 | +using torch::Tensor; |
| 10 | + |
| 11 | +void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c); |
| 12 | +void gemm_fp16_cublas(const void *a, const void *b, void *c, int m, |
| 13 | + int n, int k, bool output_fp32); |
| 14 | + |
| 15 | +__global__ void kernel_wkv_forward_new( |
| 16 | + const int B, const int T, const int C, const float *__restrict__ const _w, |
| 17 | + const float *__restrict__ const _u, const float *__restrict__ const _k, |
| 18 | + const float *__restrict__ const _v, const half *__restrict__ const r, |
| 19 | + half *__restrict__ const _y, float *__restrict__ const _aa, |
| 20 | + float *__restrict__ const _bb, float *__restrict__ const _pp) { |
| 21 | + const int idx = blockIdx.x * blockDim.x + threadIdx.x; |
| 22 | + const int _b = idx / C; |
| 23 | + const int _c = idx % C; |
| 24 | + const int _offset = _b * T * C + _c; |
| 25 | + const int _state_offset = _b * C + _c; |
| 26 | + |
| 27 | + float u = _u[_c]; |
| 28 | + float w = _w[_c]; |
| 29 | + const float *__restrict__ const k = _k + _offset; |
| 30 | + const float *__restrict__ const v = _v + _offset; |
| 31 | + half *__restrict__ const y = _y + _offset; |
| 32 | + |
| 33 | + float aa = _aa[_state_offset]; |
| 34 | + float bb = _bb[_state_offset]; |
| 35 | + float pp = _pp[_state_offset]; |
| 36 | + for (int i = 0; i < T; i++) { |
| 37 | + const int ii = i * C; |
| 38 | + const float kk = k[ii]; |
| 39 | + const float vv = v[ii]; |
| 40 | + float ww = u + kk; |
| 41 | + float p = max(pp, ww); |
| 42 | + float e1 = exp(pp - p); |
| 43 | + float e2 = exp(ww - p); |
| 44 | + // y[ii] = __hmul(__float2half((e1 * aa + e2 * vv) / (e1 * bb + e2)), |
| 45 | + // r[ii]); |
| 46 | + y[ii] = __float2half((e1 * aa + e2 * vv) / (e1 * bb + e2)); |
| 47 | + ww = w + pp; |
| 48 | + p = max(ww, kk); |
| 49 | + e1 = exp(ww - p); |
| 50 | + e2 = exp(kk - p); |
| 51 | + aa = e1 * aa + e2 * vv; |
| 52 | + bb = e1 * bb + e2; |
| 53 | + pp = p; |
| 54 | + } |
| 55 | + _aa[_state_offset] = aa; |
| 56 | + _bb[_state_offset] = bb; |
| 57 | + _pp[_state_offset] = pp; |
| 58 | +} |
| 59 | + |
| 60 | +void cuda_wkv_forward_new(int B, int T, int C, float *w, float *u, float *k, |
| 61 | + float *v, half *r, half *y, float *aa, float *bb, |
| 62 | + float *pp) { |
| 63 | + dim3 threadsPerBlock(min(C, 32)); |
| 64 | + assert(B * C % threadsPerBlock.x == 0); |
| 65 | + dim3 numBlocks(B * C / threadsPerBlock.x); |
| 66 | + kernel_wkv_forward_new<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, r, |
| 67 | + y, aa, bb, pp); |
| 68 | +} |
| 69 | + |
| 70 | +__global__ void _att_mix(const half *xx, const half *sx, const half *k_mix, |
| 71 | + const half *v_mix, const half *r_mix, |
| 72 | + const int outer_size, const int inner_size, half *kx, |
| 73 | + half *vx, half *rx) { |
| 74 | + for (int idx2 = blockIdx.x * blockDim.x + threadIdx.x; idx2 < inner_size; |
| 75 | + idx2 += blockDim.x * gridDim.x) { |
| 76 | + half k_mix_ = k_mix[idx2]; |
| 77 | + half v_mix_ = v_mix[idx2]; |
| 78 | + half r_mix_ = r_mix[idx2]; |
| 79 | + for (int row = 0; row < outer_size; ++row) { |
| 80 | + int idx1 = row * inner_size + idx2; |
| 81 | + half xx_ = xx[idx1]; |
| 82 | + half sx_ = sx[idx1]; |
| 83 | + kx[idx1] = __hadd(__hmul(xx_, k_mix_), |
| 84 | + __hmul(sx_, __hsub(__float2half(1), k_mix_))); |
| 85 | + vx[idx1] = __hadd(__hmul(xx_, v_mix_), |
| 86 | + __hmul(sx_, __hsub(__float2half(1), v_mix_))); |
| 87 | + rx[idx1] = __hadd(__hmul(xx_, r_mix_), |
| 88 | + __hmul(sx_, __hsub(__float2half(1), r_mix_))); |
| 89 | + } |
| 90 | + } |
| 91 | +} |
| 92 | + |
| 93 | +void att_mix(const half *xx, const half *sx, const half *k_mix, |
| 94 | + const half *v_mix, const half *r_mix, const int outer_size, |
| 95 | + const int inner_size, half *kx, half *vx, half *rx) { |
| 96 | + // 256 is good enough on most GPUs |
| 97 | + const int32_t BLOCK_SIZE = 256; |
| 98 | + assert(inner_size % BLOCK_SIZE == 0); |
| 99 | + _att_mix<<<inner_size / BLOCK_SIZE, BLOCK_SIZE>>>( |
| 100 | + xx, sx, k_mix, v_mix, r_mix, outer_size, inner_size, kx, vx, rx); |
| 101 | +} |
| 102 | + |
| 103 | +struct InplaceSigmoid { |
| 104 | + __device__ __forceinline__ half operator()(int i) const { |
| 105 | + ptr[i] = __float2half(1.0 / (1.0 + exp(-__half2float(ptr[i])))); |
| 106 | + } |
| 107 | + half *ptr; |
| 108 | +}; |
| 109 | + |
| 110 | +struct InplaceMul { |
| 111 | + __device__ __forceinline__ half operator()(int i) const { |
| 112 | + y[i] = __hmul(x[i], y[i]); |
| 113 | + } |
| 114 | + half *y; |
| 115 | + half *x; |
| 116 | +}; |
| 117 | + |
| 118 | +/* |
| 119 | + Equivalent Python code: |
| 120 | +
|
| 121 | + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) |
| 122 | + sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) |
| 123 | + kx = xx * k_mix + sx * (1 - k_mix) |
| 124 | + vx = xx * v_mix + sx * (1 - v_mix) |
| 125 | + rx = xx * r_mix + sx * (1 - r_mix) |
| 126 | +
|
| 127 | + r = torch.sigmoid(gemm(rx, rw)) |
| 128 | + k = gemm(kx, kw, output_dtype=torch.float32) |
| 129 | + v = gemm(vx, vw, output_dtype=torch.float32) |
| 130 | +
|
| 131 | + T = x.shape[0] |
| 132 | + for t in range(T): |
| 133 | + kk = k[t] |
| 134 | + vv = v[t] |
| 135 | + ww = t_first + kk |
| 136 | + p = torch.maximum(pp, ww) |
| 137 | + e1 = torch.exp(pp - p) |
| 138 | + e2 = torch.exp(ww - p) |
| 139 | + sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype) |
| 140 | + ww = t_decay + pp |
| 141 | + p = torch.maximum(ww, kk) |
| 142 | + e1 = torch.exp(ww - p) |
| 143 | + e2 = torch.exp(kk - p) |
| 144 | + aa = e1 * aa + e2 * vv |
| 145 | + bb = e1 * bb + e2 |
| 146 | + pp = p |
| 147 | + out = gemm(r * sx, ow) |
| 148 | + return x + out, xx[-1,:], aa, bb, pp |
| 149 | +*/ |
| 150 | +Tensor att_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix, |
| 151 | + Tensor v_mix, Tensor r_mix, Tensor kw, Tensor vw, Tensor rw, |
| 152 | + Tensor ow, Tensor t_first, Tensor pp, Tensor aa, Tensor bb, |
| 153 | + Tensor t_decay, /* imm */ Tensor buf, /* out */ Tensor x_plus_out) { |
| 154 | + Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b); |
| 155 | + sx = at::cat({sx.unsqueeze(0), xx.slice(0, 0, -1)}, 0); |
| 156 | + char* buf_ptr = (char*)buf.data_ptr(); |
| 157 | + half* kx = (half*)buf_ptr; |
| 158 | + half* vx = kx + x.numel(); |
| 159 | + half* rx = vx + x.numel(); |
| 160 | + half* wkv_y = rx + x.numel(); |
| 161 | + att_mix(data_ptr<half>(xx), data_ptr<half>(sx), data_ptr<half>(k_mix), |
| 162 | + data_ptr<half>(v_mix), data_ptr<half>(r_mix), xx.size(0), xx.size(1), |
| 163 | + kx, vx, rx); |
| 164 | + float* k = reinterpret_cast<float*>(wkv_y + x.numel()); |
| 165 | + float* v = k + x.size(0) * kw.size(1); |
| 166 | + half* r = reinterpret_cast<half*>(v + x.size(0) * vw.size(1)); |
| 167 | + |
| 168 | + gemm_fp16_cublas(kx, kw.data_ptr(), k, x.size(0), kw.size(1), kw.size(0), true); |
| 169 | + gemm_fp16_cublas(vx, vw.data_ptr(), v, x.size(0), vw.size(1), vw.size(0), true); |
| 170 | + gemm_fp16_cublas(rx, rw.data_ptr(), r, x.size(0), rw.size(1), rw.size(0), false); |
| 171 | + element_wise(InplaceSigmoid{r}, x.size(0) * rw.size(1)); |
| 172 | + cuda_wkv_forward_new(1, x.size(0), x.size(1), data_ptr<float>(t_decay), |
| 173 | + data_ptr<float>(t_first), k, v, r, |
| 174 | + wkv_y, data_ptr<float>(aa), |
| 175 | + data_ptr<float>(bb), data_ptr<float>(pp)); |
| 176 | + element_wise(InplaceMul{wkv_y, r}, x.numel()); |
| 177 | + gemm_fp16_cublas(wkv_y, ow.data_ptr(), x_plus_out.data_ptr(), x.size(0), ow.size(1), ow.size(0), false); |
| 178 | + x_plus_out += x; |
| 179 | + return xx; |
| 180 | +} |
0 commit comments