@@ -48,10 +48,12 @@ std::vector<std::pair<std::string, std::string>> shader_fnames;
4848std::locale c_locale (" C" );
4949
5050std::string GLSLC = " glslc" ;
51+ std::string input_dir = " vulkan-shaders" ;
5152std::string input_filepath = " " ;
5253std::string output_dir = " /tmp" ;
5354std::string target_hpp = " " ;
5455std::string target_cpp = " " ;
56+ bool no_clean = false ;
5557
5658const std::vector<std::string> type_names = {
5759 " f32" ,
@@ -301,7 +303,7 @@ void write_file_if_changed(const std::string& path, const std::string& content)
301303static uint32_t compile_count = 0 ;
302304static std::mutex compile_count_mutex;
303305static std::condition_variable compile_count_cond;
304- static bool generate_dep_file = true ;
306+ static bool generate_dep_file = false ;
305307
306308void decrement_compile_count (uint32_t * count) {
307309 if (count) {
@@ -402,18 +404,20 @@ void string_to_spv(std::string name, const std::string& source, const std::map<s
402404 name = name + (f16acc ? " _f16acc" : " " ) + (coopmat ? " _cm1" : " " ) + (coopmat2 ? " _cm2" : (fp16 ? " " : " _fp32" ));
403405 std::string out_path = join_paths (output_dir, name + " .spv" );
404406
405- if (input_filepath == " " ) {
406- // No input source to compile, only generate header for all shaders
407- shader_fnames.push_back (std::pair (name, out_path));
408- return ;
409- } else if (basename (input_filepath) != source) {
410- // Only compile shader variants matching the input filename
411- return ;
412- }
407+ // if (input_filepath == "") {
408+ // // No input source to compile, only generate header for all shaders
409+ // shader_fnames.push_back(std::pair(name, out_path));
410+ // return;
411+ // } else if (basename(input_filepath) != source) {
412+ // // Only compile shader variants matching the input filename
413+ // return;
414+ // }
415+
416+ std::string in_path = join_paths (input_dir, source);
413417
414418 compile_count_guard slot = acquire_compile_slot ();
415419 compiles.push_back (std::async (
416- string_to_spv_func, name, input_filepath , out_path, defines, coopmat, generate_dep_file, std::move (slot)));
420+ string_to_spv_func, name, in_path , out_path, defines, coopmat, generate_dep_file, std::move (slot)));
417421 // Don't write the same dep file from multiple processes
418422 generate_dep_file = false ;
419423}
@@ -1064,6 +1068,134 @@ void write_output_files() {
10641068 }
10651069}
10661070
1071+ void write_output_files_combined () {
1072+ FILE* hdr = fopen (target_hpp.c_str (), " w" );
1073+ FILE* src = fopen (target_cpp.c_str (), " w" );
1074+
1075+ fprintf (hdr, " #include <cstdint>\n\n " );
1076+ fprintf (src, " #include \" %s\"\n\n " , basename (target_hpp).c_str ());
1077+
1078+ std::sort (shader_fnames.begin (), shader_fnames.end ());
1079+ for (const auto & pair : shader_fnames) {
1080+ const std::string& name = pair.first ;
1081+ #ifdef _WIN32
1082+ std::string path = pair.second ;
1083+ std::replace (path.begin (), path.end (), ' /' , ' \\ ' );
1084+ #else
1085+ const std::string& path = pair.second ;
1086+ #endif
1087+
1088+ FILE* spv = fopen (path.c_str (), " rb" );
1089+ if (!spv) {
1090+ std::cerr << " Error opening SPIR-V file: " << path << " (" << strerror (errno) << " )\n " ;
1091+ continue ;
1092+ }
1093+
1094+ fseek (spv, 0 , SEEK_END);
1095+ size_t size = ftell (spv);
1096+ fseek (spv, 0 , SEEK_SET);
1097+
1098+ std::vector<unsigned char > data (size);
1099+ size_t read_size = fread (data.data (), 1 , size, spv);
1100+ fclose (spv);
1101+ if (read_size != size) {
1102+ std::cerr << " Error reading SPIR-V file: " << path << " (" << strerror (errno) << " )\n " ;
1103+ continue ;
1104+ }
1105+
1106+ fprintf (hdr, " extern unsigned char %s_data[%zu];\n " , name.c_str (), size);
1107+ fprintf (hdr, " const uint64_t %s_len = %zu;\n\n " , name.c_str (), size);
1108+
1109+ fprintf (src, " unsigned char %s_data[%zu] = {\n " , name.c_str (), size);
1110+ for (size_t i = 0 ; i < size; ++i) {
1111+ fprintf (src, " 0x%02x," , data[i]);
1112+ if ((i + 1 ) % 12 == 0 ) fprintf (src, " \n " );
1113+ }
1114+ fprintf (src, " \n };\n\n " );
1115+
1116+ if (!no_clean) {
1117+ std::remove (path.c_str ());
1118+ }
1119+ }
1120+
1121+ std::string suffixes[2 ] = {" _f32" , " _f16" };
1122+ for (const char *op : {" add" , " sub" , " mul" , " div" , " add_rms" }) {
1123+ fprintf (hdr, " extern unsigned char *%s_data[2][2][2][2];\n " , op);
1124+ fprintf (hdr, " extern uint64_t %s_len[2][2][2][2];\n " , op);
1125+ std::string data = " unsigned char *" + std::string (op) + " _data[2][2][2][2] = " ;
1126+ std::string len = " uint64_t " + std::string (op) + " _len[2][2][2][2] = " ;
1127+ for (uint32_t t0 = 0 ; t0 < 2 ; ++t0) {
1128+ if (t0 == 0 ) {
1129+ data += " {" ;
1130+ len += " {" ;
1131+ }
1132+ for (uint32_t t1 = 0 ; t1 < 2 ; ++t1) {
1133+ if (t1 == 0 ) {
1134+ data += " {" ;
1135+ len += " {" ;
1136+ }
1137+ for (uint32_t t2 = 0 ; t2 < 2 ; ++t2) {
1138+ if (t2 == 0 ) {
1139+ data += " {" ;
1140+ len += " {" ;
1141+ }
1142+ for (uint32_t rte = 0 ; rte < 2 ; ++rte) {
1143+ if (rte == 0 ) {
1144+ data += " {" ;
1145+ len += " {" ;
1146+ }
1147+ data += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0 ) ? " _rte" : " " );
1148+ len += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0 ) ? " _rte" : " " );
1149+ data += " _data," ;
1150+ len += " _len," ;
1151+ if (rte == 1 ) {
1152+ data += " }, " ;
1153+ len += " }, " ;
1154+ }
1155+ }
1156+ if (t2 == 1 ) {
1157+ data += " }, " ;
1158+ len += " }, " ;
1159+ }
1160+ }
1161+ if (t1 == 1 ) {
1162+ data += " }, " ;
1163+ len += " }, " ;
1164+ }
1165+ }
1166+ if (t0 == 1 ) {
1167+ data += " };\n " ;
1168+ len += " };\n " ;
1169+ }
1170+ }
1171+ fputs (data.c_str (), src);
1172+ fputs (len.c_str (), src);
1173+ }
1174+
1175+ std::vector<std::string> btypes = {" f16" , " f32" };
1176+
1177+ #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
1178+ btypes.push_back (" q8_1" );
1179+ #endif
1180+
1181+ for (const std::string& btype : btypes) {
1182+ for (const auto & tname : type_names) {
1183+ if (btype == " q8_1" && !is_legacy_quant (tname)) {
1184+ continue ;
1185+ }
1186+ fprintf (hdr, " extern unsigned char *arr_dmmv_%s_%s_f32_data[3];\n " , tname.c_str (), btype.c_str ());
1187+ fprintf (hdr, " extern uint64_t arr_dmmv_%s_%s_f32_len[3];\n " , tname.c_str (), btype.c_str ());
1188+ std::string data = " unsigned char *arr_dmmv_" + tname + " _" + btype + " _f32_data[3] = {mul_mat_vec_" + tname + " _" + btype + " _f32_data, mul_mat_vec_" + tname + " _" + btype + " _f32_subgroup_data, mul_mat_vec_" + tname + " _" + btype + " _f32_subgroup_no_shmem_data};\n " ;
1189+ std::string len = " uint64_t arr_dmmv_" + tname + " _" + btype + " _f32_len[3] = {mul_mat_vec_" + tname + " _" + btype + " _f32_len, mul_mat_vec_" + tname + " _" + btype + " _f32_subgroup_len, mul_mat_vec_" + tname + " _" + btype + " _f32_subgroup_no_shmem_len};\n " ;
1190+ fputs (data.c_str (), src);
1191+ fputs (len.c_str (), src);
1192+ }
1193+ }
1194+
1195+ fclose (hdr);
1196+ fclose (src);
1197+ }
1198+
10671199} // namespace
10681200
10691201int main (int argc, char ** argv) {
@@ -1086,6 +1218,9 @@ int main(int argc, char** argv) {
10861218 if (args.find (" --source" ) != args.end ()) {
10871219 input_filepath = args[" --source" ]; // The shader source file to compile
10881220 }
1221+ if (args.find (" --input-dir" ) != args.end ()) {
1222+ input_dir = args[" --input-dir" ]; // Directory containing shader sources
1223+ }
10891224 if (args.find (" --output-dir" ) != args.end ()) {
10901225 output_dir = args[" --output-dir" ]; // Directory for containing SPIR-V output
10911226 }
@@ -1096,6 +1231,11 @@ int main(int argc, char** argv) {
10961231 target_cpp = args[" --target-cpp" ]; // Path to generated cpp file
10971232 }
10981233
1234+ if (!directory_exists (input_dir)) {
1235+ std::cerr << " \" " << input_dir << " \" must be a valid directory containing shader sources" << std::endl;
1236+ return EXIT_FAILURE;
1237+ }
1238+
10991239 if (!directory_exists (output_dir)) {
11001240 if (!create_directory (output_dir)) {
11011241 std::cerr << " Error creating output directory: " << output_dir << " \n " ;
@@ -1105,7 +1245,8 @@ int main(int argc, char** argv) {
11051245
11061246 process_shaders ();
11071247
1108- write_output_files ();
1248+ // write_output_files();
1249+ write_output_files_combined ();
11091250
11101251 return EXIT_SUCCESS;
11111252}
0 commit comments