Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2835,10 +2835,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
return s;
};

bool rte = device->float_controls_rte_fp16;
#define CREATE_BINARY(name, namemod, spec) \
for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
#name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \
#name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \
"main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);

CREATE_BINARY(add, , {0})
Expand Down Expand Up @@ -2890,8 +2891,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
#undef CREATE_UNARY

#define CREATE_GLU(name) \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);
if (device->float_controls_rte_fp16) { \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
} else { \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
}

CREATE_GLU(geglu)
CREATE_GLU(reglu)
Expand Down
6 changes: 1 addition & 5 deletions ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
#version 450

#if RTE16
#extension GL_EXT_spirv_intrinsics : enable
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
#endif // RTE16

#include "rte.comp"
#include "types.comp"

#if defined(SET_ROWS) && QUANT_K == 1
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_control_flow_attributes : require

#include "rte.comp"

layout (push_constant) uniform parameter
{
uint ne;
Expand Down
2 changes: 2 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#extension GL_EXT_shader_16bit_storage : require

#include "rte.comp"

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
Expand Down
5 changes: 1 addition & 4 deletions ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
#version 450

#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_spirv_intrinsics: enable
#extension GL_EXT_control_flow_attributes : require

#if RTE16
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
#endif
#include "rte.comp"

layout (push_constant) uniform parameter
{
Expand Down
5 changes: 1 addition & 4 deletions ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
#include "types.comp"

#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_spirv_intrinsics: enable

#if RTE16
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
#endif
#include "rte.comp"

layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;

Expand Down
5 changes: 5 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/rte.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@

#if RTE16
#extension GL_EXT_spirv_intrinsics : enable
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
#endif // RTE16
85 changes: 69 additions & 16 deletions ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -537,8 +537,10 @@ void process_shaders() {
for (auto src0_f16 : {false, true}) {
for (auto src1_f16 : {false, true}) {
for (auto dst_f16 : {false, true}) {
auto name = op + get_suffix(src0_f16, src1_f16, dst_f16);
string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}});
for (auto rte : {false, true}) {
auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : "");
string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
}
}
}
}
Expand Down Expand Up @@ -592,16 +594,19 @@ void process_shaders() {
string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});

string_to_spv("geglu_f16", "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("geglu_f32", "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("reglu_f16", "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("geglu_erf_f16", "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("geglu_erf_f32", "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
for (auto rte : {false, true}) {
std::string suffix = rte ? "_rte" : "";
string_to_spv("geglu_f16" + suffix, "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("geglu_f32" + suffix, "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("reglu_f16" + suffix, "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("reglu_f32" + suffix, "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("swiglu_f16" + suffix, "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("swiglu_f32" + suffix, "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("geglu_erf_f16" + suffix, "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("geglu_erf_f32" + suffix, "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("geglu_quick_f16" + suffix,"geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("geglu_quick_f32" + suffix,"geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
}

string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
Expand Down Expand Up @@ -709,11 +714,59 @@ void write_output_files() {
std::remove(path.c_str());
}
}

std::string suffixes[2] = {"_f32", "_f16"};
for (const char *op : {"add", "sub", "mul", "div"}) {
fprintf(hdr, "extern unsigned char *%s_data[2][2][2];\n", op);
fprintf(hdr, "extern uint64_t %s_len[2][2][2];\n", op);
fprintf(src, "unsigned char *%s_data[2][2][2] = {{{%s_f32_f32_f32_data, %s_f32_f32_f16_data}, {%s_f32_f16_f32_data, %s_f32_f16_f16_data}}, {{%s_f16_f32_f32_data, %s_f16_f32_f16_data}, {%s_f16_f16_f32_data, %s_f16_f16_f16_data}}};\n", op, op, op, op, op, op, op, op, op);
fprintf(src, "uint64_t %s_len[2][2][2] = {{{%s_f32_f32_f32_len, %s_f32_f32_f16_len}, {%s_f32_f16_f32_len, %s_f32_f16_f16_len}}, {{%s_f16_f32_f32_len, %s_f16_f32_f16_len}, {%s_f16_f16_f32_len, %s_f16_f16_f16_len}}};\n", op, op, op, op, op, op, op, op, op);
fprintf(hdr, "extern unsigned char *%s_data[2][2][2][2];\n", op);
fprintf(hdr, "extern uint64_t %s_len[2][2][2][2];\n", op);
std::string data = "unsigned char *" + std::string(op) + "_data[2][2][2][2] = ";
std::string len = "uint64_t " + std::string(op) + "_len[2][2][2][2] = ";
for (uint32_t t0 = 0; t0 < 2; ++t0) {
if (t0 == 0) {
data += "{";
len += "{";
}
for (uint32_t t1 = 0; t1 < 2; ++t1) {
if (t1 == 0) {
data += "{";
len += "{";
}
for (uint32_t t2 = 0; t2 < 2; ++t2) {
if (t2 == 0) {
data += "{";
len += "{";
}
for (uint32_t rte = 0; rte < 2; ++rte) {
if (rte == 0) {
data += "{";
len += "{";
}
data += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : "");
len += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : "");
data += "_data,";
len += "_len,";
if (rte == 1) {
data += "}, ";
len += "}, ";
}
}
if (t2 == 1) {
data += "}, ";
len += "}, ";
}
}
if (t1 == 1) {
data += "}, ";
len += "}, ";
}
}
if (t0 == 1) {
data += "};\n";
len += "};\n";
}
}
fprintf(src, data.c_str());
fprintf(src, len.c_str());
}
fclose(hdr);
fclose(src);
Expand Down
Loading