Skip to content

Commit 77fe4fd

Browse files
committed
RWKV_WKV6 Vulkan op tests passed
Signed-off-by: Molly Sophia <[email protected]>
1 parent 4651f5e commit 77fe4fd

File tree

3 files changed

+6
-5
lines changed

3 files changed

+6
-5
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1961,7 +1961,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
19611961
"main",
19621962
7,
19631963
sizeof(vk_op_rwkv_wkv6_push_constants),
1964-
{64, 1, 1}, // work group
1964+
{1, 1, 1}, // work group
19651965
{device->subgroup_size},
19661966
1
19671967
);
@@ -8344,11 +8344,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
83448344
} else if (tensor->op == GGML_OP_LEAKY_RELU) {
83458345
const float * op_params = (const float *)tensor->op_params;
83468346
tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
8347+
} else if (tensor->op == GGML_OP_RWKV_WKV6) {
8348+
tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
8349+
tensor->src[4], tensor->src[5]);
83478350
}
8348-
// else if (tensor->op == GGML_OP_RWKV_WKV6) {
8349-
// tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
8350-
// tensor->src[4], tensor->src[5]);
8351-
// }
83528351
else {
83538352
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
83548353
GGML_ABORT("fatal error");

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,8 @@ void process_shaders() {
479479

480480
string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
481481

482+
string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"C_TYPE", "float"}, {"D_TYPE", "float"}, {"E_TYPE", "float"}, {"F_TYPE", "float"}, {"S_TYPE", "float"}}));
483+
482484
for (auto &c : compiles) {
483485
c.wait();
484486
}

0 commit comments

Comments
 (0)