@@ -485,10 +485,12 @@ void process_shaders() {
485485 string_to_spv (" cpy_f32_f32" , " copy.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }});
486486 string_to_spv (" cpy_f32_f16" , " copy.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float16_t" }});
487487 string_to_spv (" cpy_f16_f16" , " copy.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float16_t" }, {" OPTIMIZATION_ERROR_WORKAROUND" , " 1" }});
488+ string_to_spv (" cpy_f16_f32" , " copy.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float" }, {" OPTIMIZATION_ERROR_WORKAROUND" , " 1" }});
488489 string_to_spv (" cpy_f32_bf16" ," copy.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " uint16_t" }, {" DATA_D_BF16" , " 1" }});
489490 string_to_spv (" contig_cpy_f32_f32" , " contig_copy.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }});
490491 string_to_spv (" contig_cpy_f32_f16" , " contig_copy.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float16_t" }});
491492 string_to_spv (" contig_cpy_f16_f16" , " contig_copy.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float16_t" }, {" OPTIMIZATION_ERROR_WORKAROUND" , " 1" }});
493+ string_to_spv (" contig_cpy_f16_f32" , " contig_copy.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float" }, {" OPTIMIZATION_ERROR_WORKAROUND" , " 1" }});
492494 string_to_spv (" contig_cpy_f32_bf16" ," contig_copy.comp" ,{{" A_TYPE" , " float" }, {" D_TYPE" , " uint16_t" }, {" DATA_D_BF16" , " 1" }});
493495
494496 for (std::string t : {" q4_0" , " q4_1" , " q5_0" , " q5_1" , " q8_0" , " iq4_nl" }) {
@@ -497,8 +499,26 @@ void process_shaders() {
497499 string_to_spv (" cpy_" + t + " _f32" , " copy_from_quant.comp" , {{" DATA_A_" + to_uppercase (t), " 1" }, {" D_TYPE" , " float" }, {" FLOAT_TYPE" , " float" }});
498500 }
499501
500- string_to_spv (" add_f32" , " add.comp" , {{" A_TYPE" , " float" }, {" B_TYPE" , " float" }, {" D_TYPE" , " float" }, {" FLOAT_TYPE" , " float" }});
501- string_to_spv (" add_f16_f32_f16" , " add.comp" , {{" A_TYPE" , " float16_t" }, {" B_TYPE" , " float" }, {" D_TYPE" , " float16_t" }, {" FLOAT_TYPE" , " float" }});
502+ auto get_type_str = [](bool f16 ) {
503+ return f16 ? " float16_t" : " float" ;
504+ };
505+ auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) {
506+ std::string s;
507+ s += std::string (src0_f16 ? " _f16" : " _f32" );
508+ s += std::string (src1_f16 ? " _f16" : " _f32" );
509+ s += std::string (dst_f16 ? " _f16" : " _f32" );
510+ return s;
511+ };
512+ for (std::string op : {" add" , " sub" , " mul" , " div" }) {
513+ for (auto src0_f16 : {false , true }) {
514+ for (auto src1_f16 : {false , true }) {
515+ for (auto dst_f16 : {false , true }) {
516+ auto name = op + get_suffix (src0_f16, src1_f16, dst_f16);
517+ string_to_spv (name.c_str (), op + " .comp" , {{" A_TYPE" , get_type_str (src0_f16)}, {" B_TYPE" , get_type_str (src1_f16)}, {" D_TYPE" , get_type_str (dst_f16)}, {" FLOAT_TYPE" , " float" }});
518+ }
519+ }
520+ }
521+ }
502522
503523 string_to_spv (" sub_f32" , " sub.comp" , {{" A_TYPE" , " float" }, {" B_TYPE" , " float" }, {" D_TYPE" , " float" }, {" FLOAT_TYPE" , " float" }});
504524
@@ -533,14 +553,21 @@ void process_shaders() {
533553
534554 string_to_spv (" upscale_f32" , " upscale.comp" , {{" A_TYPE" , " float" }, {" B_TYPE" , " float" }, {" D_TYPE" , " float" }});
535555
536- string_to_spv (" gelu_f32" , " gelu.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }});
537- string_to_spv (" gelu_quick_f32" , " gelu_quick.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }});
538- string_to_spv (" silu_f32" , " silu.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }});
539- string_to_spv (" silu_back_f32" , " silu_back.comp" , {{" A_TYPE" , " float" }, {" B_TYPE" , " float" }, {" D_TYPE" , " float" }});
540- string_to_spv (" relu_f32" , " relu.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }});
541- string_to_spv (" leaky_relu_f32" , " leaky_relu.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }});
542- string_to_spv (" tanh_f32" , " tanh.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }});
543- string_to_spv (" sigmoid_f32" , " sigmoid.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }});
556+ string_to_spv (" gelu_f16" , " gelu.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float16_t" }});
557+ string_to_spv (" gelu_f32" , " gelu.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }});
558+ string_to_spv (" gelu_quick_f16" , " gelu_quick.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float16_t" }});
559+ string_to_spv (" gelu_quick_f32" , " gelu_quick.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }});
560+ string_to_spv (" silu_f16" , " silu.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float16_t" }});
561+ string_to_spv (" silu_f32" , " silu.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }});
562+ string_to_spv (" relu_f16" , " relu.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float16_t" }});
563+ string_to_spv (" relu_f32" , " relu.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }});
564+ string_to_spv (" tanh_f16" , " tanh.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float16_t" }});
565+ string_to_spv (" tanh_f32" , " tanh.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }});
566+ string_to_spv (" sigmoid_f16" , " sigmoid.comp" , {{" A_TYPE" , " float16_t" }, {" D_TYPE" , " float16_t" }});
567+ string_to_spv (" sigmoid_f32" , " sigmoid.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }});
568+
569+ string_to_spv (" leaky_relu_f32" , " leaky_relu.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }});
570+ string_to_spv (" silu_back_f32" , " silu_back.comp" , {{" A_TYPE" , " float" }, {" B_TYPE" , " float" }, {" D_TYPE" , " float" }});
544571
545572 string_to_spv (" diag_mask_inf_f32" , " diag_mask_inf.comp" , {{" A_TYPE" , " float" }, {" D_TYPE" , " float" }});
546573
@@ -641,7 +668,12 @@ void write_output_files() {
641668 std::remove (path.c_str ());
642669 }
643670 }
644-
671+ for (const char *op : {" add" , " sub" , " mul" , " div" }) {
672+ fprintf (hdr, " extern unsigned char *%s_data[2][2][2];\n " , op);
673+ fprintf (hdr, " extern uint64_t %s_len[2][2][2];\n " , op);
674+ fprintf (src, " unsigned char *%s_data[2][2][2] = {{{%s_f32_f32_f32_data, %s_f32_f32_f16_data}, {%s_f32_f16_f32_data, %s_f32_f16_f16_data}}, {{%s_f16_f32_f32_data, %s_f16_f32_f16_data}, {%s_f16_f16_f32_data, %s_f16_f16_f16_data}}};\n " , op, op, op, op, op, op, op, op, op);
675+ fprintf (src, " uint64_t %s_len[2][2][2] = {{{%s_f32_f32_f32_len, %s_f32_f32_f16_len}, {%s_f32_f16_f32_len, %s_f32_f16_f16_len}}, {{%s_f16_f32_f32_len, %s_f16_f32_f16_len}, {%s_f16_f16_f32_len, %s_f16_f16_f16_len}}};\n " , op, op, op, op, op, op, op, op, op);
676+ }
645677 fclose (hdr);
646678 fclose (src);
647679}
0 commit comments