@@ -537,8 +537,10 @@ void process_shaders() {
537537 for (auto src0_f16 : {false , true }) {
538538 for (auto src1_f16 : {false , true }) {
539539 for (auto dst_f16 : {false , true }) {
540- auto name = op + get_suffix (src0_f16, src1_f16, dst_f16);
541- string_to_spv (name.c_str (), op + " .comp" , {{" A_TYPE" , get_type_str (src0_f16)}, {" B_TYPE" , get_type_str (src1_f16)}, {" D_TYPE" , get_type_str (dst_f16)}, {" FLOAT_TYPE" , " float" }});
540+ for (auto rte : {false , true }) {
541+ auto name = op + get_suffix (src0_f16, src1_f16, dst_f16) + (rte ? " _rte" : " " );
542+ string_to_spv (name.c_str (), op + " .comp" , {{" A_TYPE" , get_type_str (src0_f16)}, {" B_TYPE" , get_type_str (src1_f16)}, {" D_TYPE" , get_type_str (dst_f16)}, {" FLOAT_TYPE" , " float" }, {" RTE16" , rte ? " 1" : " 0" }});
543+ }
542544 }
543545 }
544546 }
@@ -592,16 +594,19 @@ void process_shaders() {
592594 string_to_spv (" sigmoid_f16" , " sigmoid.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float16_t" }});
593595 string_to_spv (" sigmoid_f32" , " sigmoid.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }});
594596
595- string_to_spv (" geglu_f16" , " geglu.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float16_t" }});
596- string_to_spv (" geglu_f32" , " geglu.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }});
597- string_to_spv (" reglu_f16" , " reglu.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float16_t" }});
598- string_to_spv (" reglu_f32" , " reglu.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }});
599- string_to_spv (" swiglu_f16" , " swiglu.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float16_t" }});
600- string_to_spv (" swiglu_f32" , " swiglu.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }});
601- string_to_spv (" geglu_erf_f16" , " geglu_erf.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float16_t" }});
602- string_to_spv (" geglu_erf_f32" , " geglu_erf.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }});
603- string_to_spv (" geglu_quick_f16" ," geglu_quick.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float16_t" }});
604- string_to_spv (" geglu_quick_f32" ," geglu_quick.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }});
597+ for (auto rte : {false , true }) {
598+ std::string suffix = rte ? " _rte" : " " ;
599+ string_to_spv (" geglu_f16" + suffix, " geglu.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float16_t" }, {" RTE16" , rte ? " 1" : " 0" }});
600+ string_to_spv (" geglu_f32" + suffix, " geglu.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }, {" RTE16" , rte ? " 1" : " 0" }});
601+ string_to_spv (" reglu_f16" + suffix, " reglu.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float16_t" }, {" RTE16" , rte ? " 1" : " 0" }});
602+ string_to_spv (" reglu_f32" + suffix, " reglu.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }, {" RTE16" , rte ? " 1" : " 0" }});
603+ string_to_spv (" swiglu_f16" + suffix, " swiglu.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float16_t" }, {" RTE16" , rte ? " 1" : " 0" }});
604+ string_to_spv (" swiglu_f32" + suffix, " swiglu.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }, {" RTE16" , rte ? " 1" : " 0" }});
605+ string_to_spv (" geglu_erf_f16" + suffix, " geglu_erf.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float16_t" }, {" RTE16" , rte ? " 1" : " 0" }});
606+ string_to_spv (" geglu_erf_f32" + suffix, " geglu_erf.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }, {" RTE16" , rte ? " 1" : " 0" }});
607+ string_to_spv (" geglu_quick_f16" + suffix," geglu_quick.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float16_t" }, {" RTE16" , rte ? " 1" : " 0" }});
608+ string_to_spv (" geglu_quick_f32" + suffix," geglu_quick.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }, {" RTE16" , rte ? " 1" : " 0" }});
609+ }
605610
606611 string_to_spv (" leaky_relu_f32" , " leaky_relu.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }});
607612 string_to_spv (" silu_back_f32" , " silu_back.comp" , {{" A_TYPE" , " float" }, {" B_TYPE" , " float" }, {" D_TYPE" , " float" }});
@@ -709,11 +714,59 @@ void write_output_files() {
709714 std::remove (path.c_str ());
710715 }
711716 }
717+
718+ std::string suffixes[2 ] = {" _f32" , " _f16" };
712719 for (const char *op : {" add" , " sub" , " mul" , " div" }) {
713- fprintf (hdr, " extern unsigned char *%s_data[2][2][2];\n " , op);
714- fprintf (hdr, " extern uint64_t %s_len[2][2][2];\n " , op);
715- fprintf (src, " unsigned char *%s_data[2][2][2] = {{{%s_f32_f32_f32_data, %s_f32_f32_f16_data}, {%s_f32_f16_f32_data, %s_f32_f16_f16_data}}, {{%s_f16_f32_f32_data, %s_f16_f32_f16_data}, {%s_f16_f16_f32_data, %s_f16_f16_f16_data}}};\n " , op, op, op, op, op, op, op, op, op);
716- fprintf (src, " uint64_t %s_len[2][2][2] = {{{%s_f32_f32_f32_len, %s_f32_f32_f16_len}, {%s_f32_f16_f32_len, %s_f32_f16_f16_len}}, {{%s_f16_f32_f32_len, %s_f16_f32_f16_len}, {%s_f16_f16_f32_len, %s_f16_f16_f16_len}}};\n " , op, op, op, op, op, op, op, op, op);
720+ fprintf (hdr, " extern unsigned char *%s_data[2][2][2][2];\n " , op);
721+ fprintf (hdr, " extern uint64_t %s_len[2][2][2][2];\n " , op);
722+ std::string data = " unsigned char *" + std::string (op) + " _data[2][2][2][2] = " ;
723+ std::string len = " uint64_t " + std::string (op) + " _len[2][2][2][2] = " ;
724+ for (uint32_t t0 = 0 ; t0 < 2 ; ++t0) {
725+ if (t0 == 0 ) {
726+ data += " {" ;
727+ len += " {" ;
728+ }
729+ for (uint32_t t1 = 0 ; t1 < 2 ; ++t1) {
730+ if (t1 == 0 ) {
731+ data += " {" ;
732+ len += " {" ;
733+ }
734+ for (uint32_t t2 = 0 ; t2 < 2 ; ++t2) {
735+ if (t2 == 0 ) {
736+ data += " {" ;
737+ len += " {" ;
738+ }
739+ for (uint32_t rte = 0 ; rte < 2 ; ++rte) {
740+ if (rte == 0 ) {
741+ data += " {" ;
742+ len += " {" ;
743+ }
744+ data += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0 ) ? " _rte" : " " );
745+ len += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0 ) ? " _rte" : " " );
746+ data += " _data," ;
747+ len += " _len," ;
748+ if (rte == 1 ) {
749+ data += " }, " ;
750+ len += " }, " ;
751+ }
752+ }
753+ if (t2 == 1 ) {
754+ data += " }, " ;
755+ len += " }, " ;
756+ }
757+ }
758+ if (t1 == 1 ) {
759+ data += " }, " ;
760+ len += " }, " ;
761+ }
762+ }
763+ if (t0 == 1 ) {
764+ data += " };\n " ;
765+ len += " };\n " ;
766+ }
767+ }
768+ fputs (data.c_str (), src); // emit the generated initializer verbatim; it must never be interpreted as a printf format string
769+ fputs (len.c_str (), src); // same for the length table: fputs writes the bytes as-is, so a stray '%' cannot trigger UB
717770 }
718771 fclose (hdr);
719772 fclose (src);
0 commit comments