Commit 60bbd4e

Apply code format changes
Signed-off-by: Molly Sophia <[email protected]>
1 parent 77fe4fd commit 60bbd4e

3 files changed: +63 -117 lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 30 additions & 73 deletions
@@ -524,15 +524,13 @@ struct vk_op_pool2d_push_constants {
     int32_t p0; int32_t p1;
 };
 
-
 struct vk_op_rwkv_wkv6_push_constants {
-    uint32_t B; // Batch size (originally n_seqs)
-    uint32_t T; // Sequence length
-    uint32_t C; // Total channels
-    uint32_t H; // Number of heads (originally HEADS)
+    uint32_t B;
+    uint32_t T;
+    uint32_t C;
+    uint32_t H;
 };
 
-
 // Allow pre-recording command buffers
 struct vk_staging_memcpy {
     vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -1952,19 +1950,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(
-        device,
-        device->pipeline_rwkv_wkv6_f32,
-        "rwkv_wkv6_f32",
-        rwkv_wkv6_f32_len,
-        rwkv_wkv6_f32_data,
-        "main",
-        7,
-        sizeof(vk_op_rwkv_wkv6_push_constants),
-        {1, 1, 1}, // work group
-        {device->subgroup_size},
-        1
-    );
+    ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
 
     for (auto &c : compiles) {
         c.wait();
@@ -5348,28 +5334,14 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
     }, dryrun);
 }
 
+static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) {
+    const ggml_tensor * k = dst->src[0];
+    const ggml_tensor * v = dst->src[1];
+    const ggml_tensor * r = dst->src[2];
+    const ggml_tensor * tf = dst->src[3];
+    const ggml_tensor * td = dst->src[4];
+    const ggml_tensor * state = dst->src[5];
 
-
-template<typename PC>
-static void ggml_vk_op_f32_rwkv6(
-    ggml_backend_vk_context * ctx,
-    vk_context& subctx,
-    ggml_tensor * dst,
-    const PC&& pc,
-    bool dryrun = false) {
-
-    // Get source tensors
-    const ggml_tensor * k = dst->src[0]; // keys
-    const ggml_tensor * v = dst->src[1]; // values
-    const ggml_tensor * r = dst->src[2]; // reset gates
-    const ggml_tensor * tf = dst->src[3]; // time first
-    const ggml_tensor * td = dst->src[4]; // time decay
-    const ggml_tensor * state = dst->src[5]; // states
-
-    VK_LOG_DEBUG("ggml_vk_op_f32_rwkv6(" << k << ", " << v << ", " << r << ", "
-        << tf << ", " << td << ", " << state << ", " << dst << ")");
-
-    // Verify input types
     GGML_ASSERT(!ggml_is_quantized(k->type));
     GGML_ASSERT(!ggml_is_quantized(v->type));
     GGML_ASSERT(!ggml_is_quantized(r->type));
@@ -5378,7 +5350,6 @@ static void ggml_vk_op_f32_rwkv6(
     GGML_ASSERT(!ggml_is_quantized(state->type));
     GGML_ASSERT(dst->buffer != nullptr);
 
-    // Get pipeline
     vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6);
     GGML_ASSERT(pipeline != nullptr);
 
@@ -5387,7 +5358,6 @@ static void ggml_vk_op_f32_rwkv6(
         return;
     }
 
-    // Get buffer contexts
     ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
     ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
     ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
@@ -5396,7 +5366,6 @@ static void ggml_vk_op_f32_rwkv6(
     ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
     ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
 
-    // Get device buffers
     vk_buffer d_D = dst_buf_ctx->dev_buffer;
     vk_buffer d_K = k_buf_ctx->dev_buffer;
     vk_buffer d_V = v_buf_ctx->dev_buffer;
@@ -5405,7 +5374,6 @@ static void ggml_vk_op_f32_rwkv6(
     vk_buffer d_TD = td_buf_ctx->dev_buffer;
     vk_buffer d_State = state_buf_ctx->dev_buffer;
 
-    // Calculate buffer offsets
     const uint64_t k_offset = vk_tensor_offset(k);
     const uint64_t v_offset = vk_tensor_offset(v);
     const uint64_t r_offset = vk_tensor_offset(r);
@@ -5414,7 +5382,6 @@ static void ggml_vk_op_f32_rwkv6(
     const uint64_t state_offset = vk_tensor_offset(state);
     const uint64_t dst_offset = vk_tensor_offset(dst);
 
-    // Calculate buffer sizes
     const uint64_t k_size = ggml_nbytes(k);
     const uint64_t v_size = ggml_nbytes(v);
     const uint64_t r_size = ggml_nbytes(r);
@@ -5423,14 +5390,12 @@ static void ggml_vk_op_f32_rwkv6(
     const uint64_t state_size = ggml_nbytes(state);
     const uint64_t dst_size = ggml_nbytes(dst);
 
-    // Set work elements based on tensor dimensions
     std::array<uint32_t, 3> elements = {
-        (uint32_t)(pc.B*pc.H), // B * H workgroups
-        1, // 64 threads per workgroup
+        (uint32_t)(pc.B * pc.H),
+        1,
         1
     };
 
-    // Synchronize buffers and dispatch compute pipeline
     ggml_vk_sync_buffers(subctx);
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
         vk_subbuffer{ d_K, k_offset, k_size },
@@ -5440,35 +5405,27 @@ static void ggml_vk_op_f32_rwkv6(
         vk_subbuffer{ d_TD, td_offset, td_size },
         vk_subbuffer{ d_State, state_offset, state_size },
         vk_subbuffer{ d_D, dst_offset, dst_size }
-    }, sizeof(PC), &pc, elements);
-}
-
-static void ggml_vk_rwkv_wkv6(
-    ggml_backend_vk_context * ctx,
-    vk_context& subctx,
-    ggml_tensor * dst,
-    bool dryrun = false) {
-
-    // Extract dimensions from tensors
-    const size_t T = dst->src[0]->ne[3]; // Sequence length
-    const size_t C = dst->ne[0]; // Channel dimension
-    const size_t HEADS = dst->src[0]->ne[2]; // Number of heads
-    const size_t n_seqs = dst->src[5]->ne[1]; // Batch size
-
-    // Call implementation with push constants
-    ggml_vk_op_f32_rwkv6<vk_op_rwkv_wkv6_push_constants>(
+    }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
+}
+
+static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
+    const size_t seq_length = dst->src[0]->ne[3];
+    const size_t n_embed = dst->ne[0];
+    const size_t n_heads = dst->src[0]->ne[2];
+    const size_t n_seqs = dst->src[5]->ne[1];
+
+    ggml_vk_op_f32_rwkv6(
         ctx, subctx, dst,
         {
-            (uint32_t)n_seqs, // B
-            (uint32_t)T, // T
-            (uint32_t)C, // C
-            (uint32_t)HEADS, // H
+            (uint32_t)n_seqs,
+            (uint32_t)seq_length,
+            (uint32_t)n_embed,
+            (uint32_t)n_heads,
         },
         dryrun
     );
 }
 
-
 static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     int * op_params = (int *)dst->op_params;
 
@@ -8344,10 +8301,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     } else if (tensor->op == GGML_OP_LEAKY_RELU) {
         const float * op_params = (const float *)tensor->op_params;
         tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
-    } else if (tensor->op == GGML_OP_RWKV_WKV6) {
+    } else if (tensor->op == GGML_OP_RWKV_WKV6) {
         tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
             tensor->src[4], tensor->src[5]);
-    }
+    }
     else {
         std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
         GGML_ABORT("fatal error");
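
A note on the dispatch geometry preserved by this reformat: ggml_vk_op_f32_rwkv6 launches B*H workgroups (elements = {B*H, 1, 1}) of 64 threads each, one workgroup per (sequence, head) pair and one thread per channel of the 64-wide head. For orientation, a host-side sketch of how the push constants derive from the ggml tensor shapes, mirroring what ggml_vk_rwkv_wkv6 now does inline; the helper name is illustrative, not from the commit:

    // Hypothetical helper: derive the WKV6 push constants from dst,
    // using the same source-tensor layout the code above reads.
    static vk_op_rwkv_wkv6_push_constants wkv6_pc_from_tensor(const ggml_tensor * dst) {
        return {
            (uint32_t) dst->src[5]->ne[1], // B: number of sequences (from the state tensor)
            (uint32_t) dst->src[0]->ne[3], // T: total tokens across all sequences
            (uint32_t) dst->ne[0],         // C: channels = H * head_size (head_size is 64 here)
            (uint32_t) dst->src[0]->ne[2], // H: number of heads
        };
    }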

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 1 deletion
@@ -479,7 +479,7 @@ void process_shaders() {
 
     string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
 
-    string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"C_TYPE", "float"}, {"D_TYPE", "float"}, {"E_TYPE", "float"}, {"F_TYPE", "float"}, {"S_TYPE", "float"}}));
+    string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
     for (auto &c : compiles) {
         c.wait();
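
Why dropping the extra map entries is safe: each {"NAME", "value"} pair passed to string_to_spv becomes a preprocessor define when the shader is compiled to SPIR-V, and the reworked wkv6.comp (next file) types every buffer element as A_TYPE, so the unused B_TYPE through S_TYPE defines can go. A rough sketch of that mechanism, under the assumption that entries are forwarded as -D arguments; the actual helper in vulkan-shaders-gen.cpp may differ in detail:

    #include <map>
    #include <string>
    #include <vector>

    // Illustrative only: turn {{"A_TYPE", "float"}} into "-DA_TYPE=float"
    // style arguments for the GLSL-to-SPIR-V compiler invocation.
    static std::vector<std::string> define_args(const std::map<std::string, std::string> & defs) {
        std::vector<std::string> args;
        for (const auto & [name, value] : defs) {
            args.push_back("-D" + name + "=" + value);
        }
        return args;
    }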

ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp

Lines changed: 32 additions & 43 deletions
@@ -1,96 +1,85 @@
 #version 450
 
-
-layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
+#define BLOCK_SIZE 64
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
 
 layout(push_constant) uniform Parameters {
-    uint B; // Batch size
-    uint T; // Sequence length
-    uint C; // Total number of channels
-    uint H; // Number of heads
+    uint B;
+    uint T;
+    uint C;
+    uint H;
 };
 
-layout(set = 0, binding = 0) readonly buffer KBuf { float k[]; };
-layout(set = 0, binding = 1) readonly buffer VBuf { float v[]; };
-layout(set = 0, binding = 2) readonly buffer RBuf { float r[]; };
-layout(set = 0, binding = 3) readonly buffer TimeFBuf { float tf[]; };
-layout(set = 0, binding = 4) readonly buffer TimeDBuf { float td[]; };
-layout(set = 0, binding = 5) readonly buffer StateBuf { float state_in[]; };
-layout(set = 0, binding = 6) buffer DstBuf { float dst[]; };
+layout(binding = 0) readonly buffer KBuf { A_TYPE k[]; };
+layout(binding = 1) readonly buffer VBuf { A_TYPE v[]; };
+layout(binding = 2) readonly buffer RBuf { A_TYPE r[]; };
+layout(binding = 3) readonly buffer TimeFBuf { A_TYPE tf[]; };
+layout(binding = 4) readonly buffer TimeDBuf { A_TYPE td[]; };
+layout(binding = 5) readonly buffer StateBuf { A_TYPE state_in[]; };
+layout(binding = 6) buffer DstBuf { A_TYPE dst[]; };
 
-shared float _k[64], _r[64], _tf[64], _td[64];
+shared A_TYPE _k[BLOCK_SIZE], _r[BLOCK_SIZE], _tf[BLOCK_SIZE], _td[BLOCK_SIZE];
 
 void main() {
-    const uint head_size = 64;
+    const uint head_size = BLOCK_SIZE;
     const uint batch_id = gl_WorkGroupID.x / H;
     const uint head_id = gl_WorkGroupID.x % H;
     const uint tid = gl_LocalInvocationID.x;
-
+
     const uint state_size = C * head_size;
     const uint n_seq_tokens = T / B;
 
     if (tid >= head_size || batch_id >= B || head_id >= H) {
         return;
     }
-
-    // Load state
-    float state[64]; // Use fixed size matching head_size
+
+    A_TYPE state[BLOCK_SIZE];
     for (uint i = 0; i < head_size; i++) {
-        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
             + i * head_size + tid];
     }
-
-    _k[tid] = 0.0;
-    _r[tid] = 0.0;
-    _td[tid] = 0.0;
+
     barrier();
     _tf[tid] = tf[head_id * head_size + tid];
     barrier();
 
-
-    // Main loop
     const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
     const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
-
+
     for (uint t = start_t; t < end_t; t += C) {
         barrier();
         _k[tid] = k[t];
         _r[tid] = r[t];
         _td[tid] = td[t];
         barrier();
-
-        const float v_val = v[t];
-        float y = 0.0;
-
+
+        const A_TYPE v_val = v[t];
+        A_TYPE y = 0.0;
+
         for (uint j = 0; j < head_size; j += 4) {
-            // Load values in blocks of 4
             vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
             vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
             vec4 tf_vec = vec4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
             vec4 td_vec = vec4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
             vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);
-
-            // Compute kv products
+
             vec4 kv = k_vec * v_val;
-
-            // Accumulate results
+
             vec4 temp = tf_vec * kv + s_vec;
             y += dot(r_vec, temp);
-
-            // Update state
+
             s_vec = s_vec * td_vec + kv;
             state[j] = s_vec.x;
             state[j+1] = s_vec.y;
             state[j+2] = s_vec.z;
             state[j+3] = s_vec.w;
         }
-
+
         dst[t] = y;
     }
-
-    // Write back state
+
    for (uint i = 0; i < head_size; i++) {
-        dst[T * C + batch_id * state_size + head_id * head_size * head_size
+        dst[T * C + batch_id * state_size + head_id * head_size * head_size
             + i * head_size + tid] = state[i];
     }
-}
+}
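
For orientation, the recurrence this shader evaluates (vectorized over j in vec4 blocks) can be written as a scalar loop over one (sequence, head) slice. This is an illustrative C++ reference under the layout the shader assumes (head_size = 64; the final state is carried in/out alongside the per-token outputs), not code from the commit:

    // Scalar WKV6 reference for one (sequence, head) slice.
    // k, r, td, v are [n_tokens][64]; tf is [64]; state is [64][64] (in/out).
    // Shader thread `tid` corresponds to output channel i here.
    static void wkv6_reference(int n_tokens,
                               const float (*k)[64], const float (*v)[64],
                               const float (*r)[64], const float tf[64],
                               const float (*td)[64], float (*state)[64],
                               float (*dst)[64]) {
        for (int t = 0; t < n_tokens; t++) {
            for (int i = 0; i < 64; i++) {
                float y = 0.0f;
                for (int j = 0; j < 64; j++) {
                    const float kv = k[t][j] * v[t][i];       // key-value product
                    y += r[t][j] * (tf[j] * kv + state[j][i]); // "time first" bonus plus carried state
                    state[j][i] = state[j][i] * td[t][j] + kv; // decay state, then accumulate
                }
                dst[t][i] = y;
            }
        }
    }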
