Skip to content

Commit b5bec86

Browse files
committed
simple quick triage for vulkan compilation
1 parent c83dde8 commit b5bec86

File tree

1 file changed

+152
-11
lines changed

1 file changed

+152
-11
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 152 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,12 @@ std::vector<std::pair<std::string, std::string>> shader_fnames;
4848
std::locale c_locale("C");
4949

5050
std::string GLSLC = "glslc";
51+
std::string input_dir = "vulkan-shaders";
5152
std::string input_filepath = "";
5253
std::string output_dir = "/tmp";
5354
std::string target_hpp = "";
5455
std::string target_cpp = "";
56+
bool no_clean = false;
5557

5658
const std::vector<std::string> type_names = {
5759
"f32",
@@ -301,7 +303,7 @@ void write_file_if_changed(const std::string& path, const std::string& content)
301303
static uint32_t compile_count = 0;
302304
static std::mutex compile_count_mutex;
303305
static std::condition_variable compile_count_cond;
304-
static bool generate_dep_file = true;
306+
static bool generate_dep_file = false;
305307

306308
void decrement_compile_count(uint32_t * count) {
307309
if (count) {
@@ -402,18 +404,20 @@ void string_to_spv(std::string name, const std::string& source, const std::map<s
402404
name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
403405
std::string out_path = join_paths(output_dir, name + ".spv");
404406

405-
if (input_filepath == "") {
406-
// No input source to compile, only generate header for all shaders
407-
shader_fnames.push_back(std::pair(name, out_path));
408-
return;
409-
} else if (basename(input_filepath) != source) {
410-
// Only compile shader variants matching the input filename
411-
return;
412-
}
407+
// if (input_filepath == "") {
408+
// // No input source to compile, only generate header for all shaders
409+
// shader_fnames.push_back(std::pair(name, out_path));
410+
// return;
411+
// } else if (basename(input_filepath) != source) {
412+
// // Only compile shader variants matching the input filename
413+
// return;
414+
// }
415+
416+
std::string in_path = join_paths(input_dir, source);
413417

414418
compile_count_guard slot = acquire_compile_slot();
415419
compiles.push_back(std::async(
416-
string_to_spv_func, name, input_filepath, out_path, defines, coopmat, generate_dep_file, std::move(slot)));
420+
string_to_spv_func, name, in_path, out_path, defines, coopmat, generate_dep_file, std::move(slot)));
417421
// Don't write the same dep file from multiple processes
418422
generate_dep_file = false;
419423
}
@@ -1064,6 +1068,134 @@ void write_output_files() {
10641068
}
10651069
}
10661070

1071+
void write_output_files_combined() {
1072+
FILE* hdr = fopen(target_hpp.c_str(), "w");
1073+
FILE* src = fopen(target_cpp.c_str(), "w");
1074+
1075+
fprintf(hdr, "#include <cstdint>\n\n");
1076+
fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str());
1077+
1078+
std::sort(shader_fnames.begin(), shader_fnames.end());
1079+
for (const auto& pair : shader_fnames) {
1080+
const std::string& name = pair.first;
1081+
#ifdef _WIN32
1082+
std::string path = pair.second;
1083+
std::replace(path.begin(), path.end(), '/', '\\' );
1084+
#else
1085+
const std::string& path = pair.second;
1086+
#endif
1087+
1088+
FILE* spv = fopen(path.c_str(), "rb");
1089+
if (!spv) {
1090+
std::cerr << "Error opening SPIR-V file: " << path << " (" << strerror(errno) << ")\n";
1091+
continue;
1092+
}
1093+
1094+
fseek(spv, 0, SEEK_END);
1095+
size_t size = ftell(spv);
1096+
fseek(spv, 0, SEEK_SET);
1097+
1098+
std::vector<unsigned char> data(size);
1099+
size_t read_size = fread(data.data(), 1, size, spv);
1100+
fclose(spv);
1101+
if (read_size != size) {
1102+
std::cerr << "Error reading SPIR-V file: " << path << " (" << strerror(errno) << ")\n";
1103+
continue;
1104+
}
1105+
1106+
fprintf(hdr, "extern unsigned char %s_data[%zu];\n", name.c_str(), size);
1107+
fprintf(hdr, "const uint64_t %s_len = %zu;\n\n", name.c_str(), size);
1108+
1109+
fprintf(src, "unsigned char %s_data[%zu] = {\n", name.c_str(), size);
1110+
for (size_t i = 0; i < size; ++i) {
1111+
fprintf(src, "0x%02x,", data[i]);
1112+
if ((i + 1) % 12 == 0) fprintf(src, "\n");
1113+
}
1114+
fprintf(src, "\n};\n\n");
1115+
1116+
if (!no_clean) {
1117+
std::remove(path.c_str());
1118+
}
1119+
}
1120+
1121+
std::string suffixes[2] = {"_f32", "_f16"};
1122+
for (const char *op : {"add", "sub", "mul", "div", "add_rms"}) {
1123+
fprintf(hdr, "extern unsigned char *%s_data[2][2][2][2];\n", op);
1124+
fprintf(hdr, "extern uint64_t %s_len[2][2][2][2];\n", op);
1125+
std::string data = "unsigned char *" + std::string(op) + "_data[2][2][2][2] = ";
1126+
std::string len = "uint64_t " + std::string(op) + "_len[2][2][2][2] = ";
1127+
for (uint32_t t0 = 0; t0 < 2; ++t0) {
1128+
if (t0 == 0) {
1129+
data += "{";
1130+
len += "{";
1131+
}
1132+
for (uint32_t t1 = 0; t1 < 2; ++t1) {
1133+
if (t1 == 0) {
1134+
data += "{";
1135+
len += "{";
1136+
}
1137+
for (uint32_t t2 = 0; t2 < 2; ++t2) {
1138+
if (t2 == 0) {
1139+
data += "{";
1140+
len += "{";
1141+
}
1142+
for (uint32_t rte = 0; rte < 2; ++rte) {
1143+
if (rte == 0) {
1144+
data += "{";
1145+
len += "{";
1146+
}
1147+
data += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : "");
1148+
len += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : "");
1149+
data += "_data,";
1150+
len += "_len,";
1151+
if (rte == 1) {
1152+
data += "}, ";
1153+
len += "}, ";
1154+
}
1155+
}
1156+
if (t2 == 1) {
1157+
data += "}, ";
1158+
len += "}, ";
1159+
}
1160+
}
1161+
if (t1 == 1) {
1162+
data += "}, ";
1163+
len += "}, ";
1164+
}
1165+
}
1166+
if (t0 == 1) {
1167+
data += "};\n";
1168+
len += "};\n";
1169+
}
1170+
}
1171+
fputs(data.c_str(), src);
1172+
fputs(len.c_str(), src);
1173+
}
1174+
1175+
std::vector<std::string> btypes = {"f16", "f32"};
1176+
1177+
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
1178+
btypes.push_back("q8_1");
1179+
#endif
1180+
1181+
for (const std::string& btype : btypes) {
1182+
for (const auto& tname : type_names) {
1183+
if (btype == "q8_1" && !is_legacy_quant(tname)) {
1184+
continue;
1185+
}
1186+
fprintf(hdr, "extern unsigned char *arr_dmmv_%s_%s_f32_data[3];\n", tname.c_str(), btype.c_str());
1187+
fprintf(hdr, "extern uint64_t arr_dmmv_%s_%s_f32_len[3];\n", tname.c_str(), btype.c_str());
1188+
std::string data = "unsigned char *arr_dmmv_" + tname + "_" + btype + "_f32_data[3] = {mul_mat_vec_" + tname + "_" + btype + "_f32_data, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_data, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_no_shmem_data};\n";
1189+
std::string len = "uint64_t arr_dmmv_" + tname + "_" + btype + "_f32_len[3] = {mul_mat_vec_" + tname + "_" + btype + "_f32_len, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_len, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_no_shmem_len};\n";
1190+
fputs(data.c_str(), src);
1191+
fputs(len.c_str(), src);
1192+
}
1193+
}
1194+
1195+
fclose(hdr);
1196+
fclose(src);
1197+
}
1198+
10671199
} // namespace
10681200

10691201
int main(int argc, char** argv) {
@@ -1086,6 +1218,9 @@ int main(int argc, char** argv) {
10861218
if (args.find("--source") != args.end()) {
10871219
input_filepath = args["--source"]; // The shader source file to compile
10881220
}
1221+
if (args.find("--input-dir") != args.end()) {
1222+
input_dir = args["--input-dir"]; // Directory containing shader sources
1223+
}
10891224
if (args.find("--output-dir") != args.end()) {
10901225
output_dir = args["--output-dir"]; // Directory for containing SPIR-V output
10911226
}
@@ -1096,6 +1231,11 @@ int main(int argc, char** argv) {
10961231
target_cpp = args["--target-cpp"]; // Path to generated cpp file
10971232
}
10981233

1234+
if (!directory_exists(input_dir)) {
1235+
std::cerr << "\"" << input_dir << "\" must be a valid directory containing shader sources" << std::endl;
1236+
return EXIT_FAILURE;
1237+
}
1238+
10991239
if (!directory_exists(output_dir)) {
11001240
if (!create_directory(output_dir)) {
11011241
std::cerr << "Error creating output directory: " << output_dir << "\n";
@@ -1105,7 +1245,8 @@ int main(int argc, char** argv) {
11051245

11061246
process_shaders();
11071247

1108-
write_output_files();
1248+
//write_output_files();
1249+
write_output_files_combined();
11091250

11101251
return EXIT_SUCCESS;
11111252
}

0 commit comments

Comments
 (0)