Skip to content

Commit 4651f5e

Browse files
zhiyuan1i and MollySophia
authored and committed
rwkv_wkv6 vulkan shader
1 parent 5555c0c commit 4651f5e

File tree

2 files changed

+260
-1
lines changed

2 files changed

+260
-1
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 164 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ struct vk_device_struct {
240240
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
241241
vk_pipeline pipeline_timestep_embedding_f32;
242242
vk_pipeline pipeline_pool2d_f32;
243+
vk_pipeline pipeline_rwkv_wkv6_f32;
243244

244245
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
245246
vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
@@ -523,6 +524,15 @@ struct vk_op_pool2d_push_constants {
523524
int32_t p0; int32_t p1;
524525
};
525526

527+
528+
// Push constants for the RWKV v6 WKV compute shader.
// Mirrors the `Parameters` push-constant block in rwkv_wkv6.comp.
struct vk_op_rwkv_wkv6_push_constants {
    uint32_t B; // Batch size (number of sequences, n_seqs)
    uint32_t T; // Sequence length (total tokens across the batch)
    uint32_t C; // Total number of channels
    uint32_t H; // Number of heads
};
534+
535+
526536
// Allow pre-recording command buffers
527537
struct vk_staging_memcpy {
528538
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -1942,6 +1952,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
19421952

19431953
ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
19441954

1955+
ggml_vk_create_pipeline(
1956+
device,
1957+
device->pipeline_rwkv_wkv6_f32,
1958+
"rwkv_wkv6_f32",
1959+
rwkv_wkv6_f32_len,
1960+
rwkv_wkv6_f32_data,
1961+
"main",
1962+
7,
1963+
sizeof(vk_op_rwkv_wkv6_push_constants),
1964+
{64, 1, 1}, // work group
1965+
{device->subgroup_size},
1966+
1
1967+
);
1968+
19451969
for (auto &c : compiles) {
19461970
c.wait();
19471971
}
@@ -4917,6 +4941,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
49174941
return ctx->device->pipeline_pool2d_f32;
49184942
}
49194943
return nullptr;
4944+
case GGML_OP_RWKV_WKV6:
4945+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
4946+
return ctx->device->pipeline_rwkv_wkv6_f32;
4947+
}
4948+
return nullptr;
49204949
case GGML_OP_LEAKY_RELU:
49214950
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
49224951
return ctx->device->pipeline_leaky_relu_f32;
@@ -5319,6 +5348,127 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
53195348
}, dryrun);
53205349
}
53215350

5351+
5352+
5353+
template<typename PC>
5354+
static void ggml_vk_op_f32_rwkv6(
5355+
ggml_backend_vk_context * ctx,
5356+
vk_context& subctx,
5357+
ggml_tensor * dst,
5358+
const PC&& pc,
5359+
bool dryrun = false) {
5360+
5361+
// Get source tensors
5362+
const ggml_tensor * k = dst->src[0]; // keys
5363+
const ggml_tensor * v = dst->src[1]; // values
5364+
const ggml_tensor * r = dst->src[2]; // reset gates
5365+
const ggml_tensor * tf = dst->src[3]; // time first
5366+
const ggml_tensor * td = dst->src[4]; // time decay
5367+
const ggml_tensor * state = dst->src[5]; // states
5368+
5369+
VK_LOG_DEBUG("ggml_vk_op_f32_rwkv6(" << k << ", " << v << ", " << r << ", "
5370+
<< tf << ", " << td << ", " << state << ", " << dst << ")");
5371+
5372+
// Verify input types
5373+
GGML_ASSERT(!ggml_is_quantized(k->type));
5374+
GGML_ASSERT(!ggml_is_quantized(v->type));
5375+
GGML_ASSERT(!ggml_is_quantized(r->type));
5376+
GGML_ASSERT(!ggml_is_quantized(tf->type));
5377+
GGML_ASSERT(!ggml_is_quantized(td->type));
5378+
GGML_ASSERT(!ggml_is_quantized(state->type));
5379+
GGML_ASSERT(dst->buffer != nullptr);
5380+
5381+
// Get pipeline
5382+
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6);
5383+
GGML_ASSERT(pipeline != nullptr);
5384+
5385+
if (dryrun) {
5386+
ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
5387+
return;
5388+
}
5389+
5390+
// Get buffer contexts
5391+
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
5392+
ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
5393+
ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
5394+
ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context;
5395+
ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context;
5396+
ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
5397+
ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
5398+
5399+
// Get device buffers
5400+
vk_buffer d_D = dst_buf_ctx->dev_buffer;
5401+
vk_buffer d_K = k_buf_ctx->dev_buffer;
5402+
vk_buffer d_V = v_buf_ctx->dev_buffer;
5403+
vk_buffer d_R = r_buf_ctx->dev_buffer;
5404+
vk_buffer d_TF = tf_buf_ctx->dev_buffer;
5405+
vk_buffer d_TD = td_buf_ctx->dev_buffer;
5406+
vk_buffer d_State = state_buf_ctx->dev_buffer;
5407+
5408+
// Calculate buffer offsets
5409+
const uint64_t k_offset = vk_tensor_offset(k);
5410+
const uint64_t v_offset = vk_tensor_offset(v);
5411+
const uint64_t r_offset = vk_tensor_offset(r);
5412+
const uint64_t tf_offset = vk_tensor_offset(tf);
5413+
const uint64_t td_offset = vk_tensor_offset(td);
5414+
const uint64_t state_offset = vk_tensor_offset(state);
5415+
const uint64_t dst_offset = vk_tensor_offset(dst);
5416+
5417+
// Calculate buffer sizes
5418+
const uint64_t k_size = ggml_nbytes(k);
5419+
const uint64_t v_size = ggml_nbytes(v);
5420+
const uint64_t r_size = ggml_nbytes(r);
5421+
const uint64_t tf_size = ggml_nbytes(tf);
5422+
const uint64_t td_size = ggml_nbytes(td);
5423+
const uint64_t state_size = ggml_nbytes(state);
5424+
const uint64_t dst_size = ggml_nbytes(dst);
5425+
5426+
// Set work elements based on tensor dimensions
5427+
std::array<uint32_t, 3> elements = {
5428+
(uint32_t)(pc.B*pc.H), // B * H workgroups
5429+
1, // 每个workgroup 64个线程
5430+
1
5431+
};
5432+
5433+
// Synchronize buffers and dispatch compute pipeline
5434+
ggml_vk_sync_buffers(subctx);
5435+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
5436+
vk_subbuffer{ d_K, k_offset, k_size },
5437+
vk_subbuffer{ d_V, v_offset, v_size },
5438+
vk_subbuffer{ d_R, r_offset, r_size },
5439+
vk_subbuffer{ d_TF, tf_offset, tf_size },
5440+
vk_subbuffer{ d_TD, td_offset, td_size },
5441+
vk_subbuffer{ d_State, state_offset, state_size },
5442+
vk_subbuffer{ d_D, dst_offset, dst_size }
5443+
}, sizeof(PC), &pc, elements);
5444+
}
5445+
5446+
// Extract the WKV6 problem dimensions from the graph node and forward them
// to the dispatch helper as push constants.
static void ggml_vk_rwkv_wkv6(
    ggml_backend_vk_context * ctx,
    vk_context& subctx,
    ggml_tensor * dst,
    bool dryrun = false) {

    const ggml_tensor * keys   = dst->src[0];
    const ggml_tensor * states = dst->src[5];

    const uint32_t n_seqs     = (uint32_t)states->ne[1]; // B: batch size
    const uint32_t seq_len    = (uint32_t)keys->ne[3];   // T: sequence length
    const uint32_t n_channels = (uint32_t)dst->ne[0];    // C: total channels
    const uint32_t n_heads    = (uint32_t)keys->ne[2];   // H: number of heads

    ggml_vk_op_f32_rwkv6<vk_op_rwkv_wkv6_push_constants>(
        ctx, subctx, dst,
        { n_seqs, seq_len, n_channels, n_heads },
        dryrun
    );
}
5470+
5471+
53225472
static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
53235473
int * op_params = (int *)dst->op_params;
53245474

@@ -6464,6 +6614,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
64646614
case GGML_OP_IM2COL:
64656615
case GGML_OP_TIMESTEP_EMBEDDING:
64666616
case GGML_OP_POOL_2D:
6617+
case GGML_OP_RWKV_WKV6:
64676618
case GGML_OP_LEAKY_RELU:
64686619
case GGML_OP_FLASH_ATTN_EXT:
64696620
break;
@@ -6663,6 +6814,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
66636814
case GGML_OP_FLASH_ATTN_EXT:
66646815
ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun);
66656816

6817+
break;
6818+
6819+
case GGML_OP_RWKV_WKV6:
6820+
ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun);
6821+
66666822
break;
66676823
default:
66686824
return false;
@@ -6743,6 +6899,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
67436899
case GGML_OP_IM2COL:
67446900
case GGML_OP_TIMESTEP_EMBEDDING:
67456901
case GGML_OP_POOL_2D:
6902+
case GGML_OP_RWKV_WKV6:
67466903
case GGML_OP_LEAKY_RELU:
67476904
case GGML_OP_REPEAT:
67486905
buf = tensor->buffer;
@@ -7610,6 +7767,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
76107767
case GGML_OP_IM2COL:
76117768
case GGML_OP_TIMESTEP_EMBEDDING:
76127769
case GGML_OP_POOL_2D:
7770+
case GGML_OP_RWKV_WKV6:
76137771
case GGML_OP_LEAKY_RELU:
76147772
return true;
76157773
default:
@@ -8186,7 +8344,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
81868344
} else if (tensor->op == GGML_OP_LEAKY_RELU) {
81878345
const float * op_params = (const float *)tensor->op_params;
81888346
tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
8189-
} else {
8347+
}
8348+
// else if (tensor->op == GGML_OP_RWKV_WKV6) {
8349+
// tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
8350+
// tensor->src[4], tensor->src[5]);
8351+
// }
8352+
else {
81908353
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
81918354
GGML_ABORT("fatal error");
81928355
}
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
#version 450

// RWKV v6 WKV compute shader.
//
// One workgroup processes one (batch, head) pair; each of the 64 invocations
// owns one channel (`tid`) of the head and keeps that channel's column of the
// 64x64 per-head state in registers across the token loop.
//
// Buffer layouts implied by the indexing below (float32 throughout):
//   k, v, r, td : per-token tensors, C floats per token, T tokens
//   tf          : per-head time_first, head_size floats per head (token-invariant)
//   state_in    : per-sequence state, C * head_size floats per sequence
//   dst         : first T*C floats are the outputs, followed by the updated
//                 state in the same layout as state_in

layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;

layout(push_constant) uniform Parameters {
    uint B; // Batch size
    uint T; // Sequence length
    uint C; // Total number of channels
    uint H; // Number of heads
};

layout(set = 0, binding = 0) readonly buffer KBuf { float k[]; };
layout(set = 0, binding = 1) readonly buffer VBuf { float v[]; };
layout(set = 0, binding = 2) readonly buffer RBuf { float r[]; };
layout(set = 0, binding = 3) readonly buffer TimeFBuf { float tf[]; };
layout(set = 0, binding = 4) readonly buffer TimeDBuf { float td[]; };
layout(set = 0, binding = 5) readonly buffer StateBuf { float state_in[]; };
layout(set = 0, binding = 6) buffer DstBuf { float dst[]; };

// Current token's values for all 64 channels of this head, staged in shared
// memory so every invocation can read the whole head each iteration.
shared float _k[64], _r[64], _tf[64], _td[64];

void main() {
    const uint head_size = 64;
    const uint batch_id = gl_WorkGroupID.x / H;
    const uint head_id = gl_WorkGroupID.x % H;
    const uint tid = gl_LocalInvocationID.x;

    const uint state_size = C * head_size;   // floats of state per sequence
    const uint n_seq_tokens = T / B;         // tokens per sequence

    // NOTE(review): with a 64-wide workgroup and a B*H-sized dispatch these
    // bounds should never trigger; if they could diverge within a workgroup,
    // returning before the barrier()s below would be undefined behavior.
    if (tid >= head_size || batch_id >= B || head_id >= H) {
        return;
    }

    // Load this channel's column of the per-head state into registers.
    float state[64]; // Use fixed size matching head_size
    for (uint i = 0; i < head_size; i++) {
        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
            + i * head_size + tid];
    }

    _k[tid] = 0.0;
    _r[tid] = 0.0;
    _td[tid] = 0.0;
    barrier();
    // time_first is constant across tokens, so it is loaded once up front.
    _tf[tid] = tf[head_id * head_size + tid];
    barrier();


    // Main loop: `t` indexes this invocation's channel within each token of
    // this sequence, advancing one token (C floats) per iteration.
    const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
    const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;

    for (uint t = start_t; t < end_t; t += C) {
        barrier(); // all reads of the previous token's shared values must finish first
        _k[tid] = k[t];
        _r[tid] = r[t];
        _td[tid] = td[t];
        barrier();

        const float v_val = v[t];
        float y = 0.0;

        // Accumulate y = sum_j r[j] * (tf[j]*k[j]*v + state[j]) and update
        // state[j] = state[j]*td[j] + k[j]*v, four channels at a time.
        for (uint j = 0; j < head_size; j += 4) {
            // Load values in blocks of 4
            vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
            vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
            vec4 tf_vec = vec4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
            vec4 td_vec = vec4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
            vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);

            // Compute kv products
            vec4 kv = k_vec * v_val;

            // Accumulate results
            vec4 temp = tf_vec * kv + s_vec;
            y += dot(r_vec, temp);

            // Update state
            s_vec = s_vec * td_vec + kv;
            state[j] = s_vec.x;
            state[j+1] = s_vec.y;
            state[j+2] = s_vec.z;
            state[j+3] = s_vec.w;
        }

        dst[t] = y;
    }

    // Write back state into dst after the T*C output block.
    for (uint i = 0; i < head_size; i++) {
        dst[T * C + batch_id * state_size + head_id * head_size * head_size
            + i * head_size + tid] = state[i];
    }
}

0 commit comments

Comments
 (0)