Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 33 additions & 20 deletions ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ enum MatMulIdType {

namespace {

void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
void execute_command(const std::vector<std::string>& command, std::string& stdout_str, std::string& stderr_str) {
#ifdef _WIN32
HANDLE stdout_read, stdout_write;
HANDLE stderr_read, stderr_write;
Expand All @@ -99,7 +99,13 @@ void execute_command(const std::string& command, std::string& stdout_str, std::s
si.hStdOutput = stdout_write;
si.hStdError = stderr_write;

std::vector<char> cmd(command.begin(), command.end());
std::vector<char> cmd;
for (const auto& part : command) {
for (char c : part) {
cmd.push_back(c);
}
cmd.push_back(' ');
}
cmd.push_back('\0');

if (!CreateProcessA(NULL, cmd.data(), NULL, NULL, TRUE, 0, NULL, NULL, &si, &pi)) {
Expand Down Expand Up @@ -138,14 +144,20 @@ void execute_command(const std::string& command, std::string& stdout_str, std::s
throw std::runtime_error("Failed to fork process");
}

std::vector<char*> argv;
for (const std::string& part : command) {
argv.push_back(const_cast<char *>(part.c_str()));
}
argv.push_back(nullptr);

if (pid == 0) {
close(stdout_pipe[0]);
close(stderr_pipe[0]);
dup2(stdout_pipe[1], STDOUT_FILENO);
dup2(stderr_pipe[1], STDERR_FILENO);
close(stdout_pipe[1]);
close(stderr_pipe[1]);
execl("/bin/sh", "sh", "-c", command.c_str(), (char*) nullptr);
execvp(argv[0], argv.data());
_exit(EXIT_FAILURE);
} else {
close(stdout_pipe[1]);
Expand Down Expand Up @@ -316,21 +328,23 @@ compile_count_guard acquire_compile_slot() {
void string_to_spv_func(std::string name, std::string in_path, std::string out_path, std::map<std::string, std::string> defines, bool coopmat, bool dep_file, compile_count_guard slot) {
std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2";

// disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734
// disable spirv-opt for bf16 shaders for https://github.com/ggml-org/llama.cpp/issues/15344
// disable spirv-opt for rope shaders for https://github.com/ggml-org/llama.cpp/issues/16860
std::string opt_level = (coopmat || name.find("bf16") != std::string::npos || name.find("rope") != std::string::npos) ? "" : "-O";

#ifdef _WIN32
std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_path + "\""};
std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, "\"" + in_path + "\"", "-o", "\"" + out_path + "\""};
#else
std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, in_path, "-o", out_path};
std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, in_path, "-o", out_path};
#endif

// disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734
// disable spirv-opt for bf16 shaders for https://github.com/ggml-org/llama.cpp/issues/15344
// disable spirv-opt for rope shaders for https://github.com/ggml-org/llama.cpp/issues/16860
if (coopmat || name.find("bf16") != std::string::npos || name.find("rope") != std::string::npos) {
cmd.push_back("-O");
}

if (dep_file) {
cmd.push_back("-MD");
cmd.push_back("-MF");
cmd.push_back("\"" + target_cpp + ".d\"");
cmd.push_back(target_cpp + ".d");
}

#ifdef GGML_VULKAN_SHADER_DEBUG_INFO
Expand All @@ -354,9 +368,13 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p
// }
// std::cout << std::endl;

execute_command(command, stdout_str, stderr_str);
execute_command(cmd, stdout_str, stderr_str);
if (!stderr_str.empty()) {
std::cerr << "cannot compile " << name << "\n\n" << command << "\n\n" << stderr_str << std::endl;
std::cerr << "cannot compile " << name << "\n\n";
for (const auto& part : cmd) {
std::cerr << part << " ";
}
std::cerr << "\n\n" << stderr_str << std::endl;
return;
}

Expand Down Expand Up @@ -430,7 +448,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
base_dict["ACC_TYPE" ] = f16acc ? "float16_t" : "float";
base_dict["ACC_TYPE_VEC2"] = f16acc ? "f16vec2" : "vec2";
if (f16acc) {
base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\"";
base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
}

if (coopmat) {
Expand Down Expand Up @@ -610,7 +628,7 @@ void process_shaders() {
fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4";
if (f16acc) {
fa_base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\"";
fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
}

for (const auto& tname : type_names) {
Expand Down Expand Up @@ -1081,11 +1099,6 @@ int main(int argc, char** argv) {

if (args.find("--glslc") != args.end()) {
GLSLC = args["--glslc"]; // Path to glslc

if (!std::filesystem::exists(GLSLC) || !std::filesystem::is_regular_file(GLSLC)) {
std::cerr << "Error: glslc not found at " << GLSLC << std::endl;
return EXIT_FAILURE;
}
}
if (args.find("--source") != args.end()) {
input_filepath = args["--source"]; // The shader source file to compile
Expand Down
Loading