@@ -537,8 +537,10 @@ void process_shaders() {
537537    for  (auto  src0_f16 : {false , true }) {
538538    for  (auto  src1_f16 : {false , true }) {
539539    for  (auto  dst_f16  : {false , true }) {
540-         auto  name = op + get_suffix (src0_f16, src1_f16, dst_f16);
541-         string_to_spv (name.c_str (), op + " .comp"  , {{" A_TYPE"  , get_type_str (src0_f16)}, {" B_TYPE"  , get_type_str (src1_f16)}, {" D_TYPE"  , get_type_str (dst_f16)}, {" FLOAT_TYPE"  , " float"  }});
540+     for  (auto  rte      : {false , true }) {
541+         auto  name = op + get_suffix (src0_f16, src1_f16, dst_f16) + (rte ? " _rte"   : " "  );
542+         string_to_spv (name.c_str (), op + " .comp"  , {{" A_TYPE"  , get_type_str (src0_f16)}, {" B_TYPE"  , get_type_str (src1_f16)}, {" D_TYPE"  , get_type_str (dst_f16)}, {" FLOAT_TYPE"  , " float"  }, {" RTE16"  , rte ? " 1"   : " 0"  }});
543+     }
542544    }
543545    }
544546    }
@@ -592,16 +594,19 @@ void process_shaders() {
592594    string_to_spv (" sigmoid_f16"  ,    " sigmoid.comp"  ,     {{" A_TYPE"  , " float16_t"  },   {" D_TYPE"  , " float16_t"  }});
593595    string_to_spv (" sigmoid_f32"  ,    " sigmoid.comp"  ,     {{" A_TYPE"  , " float"  },       {" D_TYPE"  , " float"  }});
594596
595-     string_to_spv (" geglu_f16"  ,      " geglu.comp"  ,       {{" A_TYPE"  , " float16_t"  },   {" D_TYPE"  , " float16_t"  }});
596-     string_to_spv (" geglu_f32"  ,      " geglu.comp"  ,       {{" A_TYPE"  , " float"  },       {" D_TYPE"  , " float"  }});
597-     string_to_spv (" reglu_f16"  ,      " reglu.comp"  ,       {{" A_TYPE"  , " float16_t"  },   {" D_TYPE"  , " float16_t"  }});
598-     string_to_spv (" reglu_f32"  ,      " reglu.comp"  ,       {{" A_TYPE"  , " float"  },       {" D_TYPE"  , " float"  }});
599-     string_to_spv (" swiglu_f16"  ,     " swiglu.comp"  ,      {{" A_TYPE"  , " float16_t"  },   {" D_TYPE"  , " float16_t"  }});
600-     string_to_spv (" swiglu_f32"  ,     " swiglu.comp"  ,      {{" A_TYPE"  , " float"  },       {" D_TYPE"  , " float"  }});
601-     string_to_spv (" geglu_erf_f16"  ,  " geglu_erf.comp"  ,   {{" A_TYPE"  , " float16_t"  },   {" D_TYPE"  , " float16_t"  }});
602-     string_to_spv (" geglu_erf_f32"  ,  " geglu_erf.comp"  ,   {{" A_TYPE"  , " float"  },       {" D_TYPE"  , " float"  }});
603-     string_to_spv (" geglu_quick_f16"  ," geglu_quick.comp"  , {{" A_TYPE"  , " float16_t"  },   {" D_TYPE"  , " float16_t"  }});
604-     string_to_spv (" geglu_quick_f32"  ," geglu_quick.comp"  , {{" A_TYPE"  , " float"  },       {" D_TYPE"  , " float"  }});
597+     for  (auto  rte : {false , true }) {
598+         std::string suffix = rte ? " _rte"   : " "  ;
599+         string_to_spv (" geglu_f16"   + suffix,      " geglu.comp"  ,       {{" A_TYPE"  , " float16_t"  },   {" D_TYPE"  , " float16_t"  },   {" RTE16"  , rte ? " 1"   : " 0"  }});
600+         string_to_spv (" geglu_f32"   + suffix,      " geglu.comp"  ,       {{" A_TYPE"  , " float"  },       {" D_TYPE"  , " float"  },       {" RTE16"  , rte ? " 1"   : " 0"  }});
601+         string_to_spv (" reglu_f16"   + suffix,      " reglu.comp"  ,       {{" A_TYPE"  , " float16_t"  },   {" D_TYPE"  , " float16_t"  },   {" RTE16"  , rte ? " 1"   : " 0"  }});
602+         string_to_spv (" reglu_f32"   + suffix,      " reglu.comp"  ,       {{" A_TYPE"  , " float"  },       {" D_TYPE"  , " float"  },       {" RTE16"  , rte ? " 1"   : " 0"  }});
603+         string_to_spv (" swiglu_f16"   + suffix,     " swiglu.comp"  ,      {{" A_TYPE"  , " float16_t"  },   {" D_TYPE"  , " float16_t"  },   {" RTE16"  , rte ? " 1"   : " 0"  }});
604+         string_to_spv (" swiglu_f32"   + suffix,     " swiglu.comp"  ,      {{" A_TYPE"  , " float"  },       {" D_TYPE"  , " float"  },       {" RTE16"  , rte ? " 1"   : " 0"  }});
605+         string_to_spv (" geglu_erf_f16"   + suffix,  " geglu_erf.comp"  ,   {{" A_TYPE"  , " float16_t"  },   {" D_TYPE"  , " float16_t"  },   {" RTE16"  , rte ? " 1"   : " 0"  }});
606+         string_to_spv (" geglu_erf_f32"   + suffix,  " geglu_erf.comp"  ,   {{" A_TYPE"  , " float"  },       {" D_TYPE"  , " float"  },       {" RTE16"  , rte ? " 1"   : " 0"  }});
607+         string_to_spv (" geglu_quick_f16"   + suffix," geglu_quick.comp"  , {{" A_TYPE"  , " float16_t"  },   {" D_TYPE"  , " float16_t"  },   {" RTE16"  , rte ? " 1"   : " 0"  }});
608+         string_to_spv (" geglu_quick_f32"   + suffix," geglu_quick.comp"  , {{" A_TYPE"  , " float"  },       {" D_TYPE"  , " float"  },       {" RTE16"  , rte ? " 1"   : " 0"  }});
609+     }
605610
606611    string_to_spv (" leaky_relu_f32"  , " leaky_relu.comp"  ,  {{" A_TYPE"  , " float"  }, {" D_TYPE"  , " float"  }});
607612    string_to_spv (" silu_back_f32"  ,  " silu_back.comp"  ,   {{" A_TYPE"  , " float"  }, {" B_TYPE"  , " float"  }, {" D_TYPE"  , " float"  }});
@@ -709,11 +714,59 @@ void write_output_files() {
709714            std::remove (path.c_str ());
710715        }
711716    }
717+ 
718+     std::string suffixes[2 ] = {" _f32"  , " _f16"  };
712719    for  (const  char  *op : {" add"  , " sub"  , " mul"  , " div"  }) {
713-         fprintf (hdr, " extern unsigned char *%s_data[2][2][2];\n "  , op);
714-         fprintf (hdr, " extern uint64_t %s_len[2][2][2];\n "  , op);
715-         fprintf (src, " unsigned char *%s_data[2][2][2] = {{{%s_f32_f32_f32_data, %s_f32_f32_f16_data}, {%s_f32_f16_f32_data, %s_f32_f16_f16_data}}, {{%s_f16_f32_f32_data, %s_f16_f32_f16_data}, {%s_f16_f16_f32_data, %s_f16_f16_f16_data}}};\n "  , op, op, op, op, op, op, op, op, op);
716-         fprintf (src, " uint64_t %s_len[2][2][2] = {{{%s_f32_f32_f32_len, %s_f32_f32_f16_len}, {%s_f32_f16_f32_len, %s_f32_f16_f16_len}}, {{%s_f16_f32_f32_len, %s_f16_f32_f16_len}, {%s_f16_f16_f32_len, %s_f16_f16_f16_len}}};\n "  , op, op, op, op, op, op, op, op, op);
720+         fprintf (hdr, " extern unsigned char *%s_data[2][2][2][2];\n "  , op);
721+         fprintf (hdr, " extern uint64_t %s_len[2][2][2][2];\n "  , op);
722+         std::string data = " unsigned char *"   + std::string (op) + " _data[2][2][2][2] = "  ;
723+         std::string len = " uint64_t "   + std::string (op) + " _len[2][2][2][2] = "  ;
724+         for  (uint32_t  t0 = 0 ; t0 < 2 ; ++t0) {
725+             if  (t0 == 0 ) {
726+                 data += " {"  ;
727+                 len += " {"  ;
728+             }
729+             for  (uint32_t  t1 = 0 ; t1 < 2 ; ++t1) {
730+                 if  (t1 == 0 ) {
731+                     data += " {"  ;
732+                     len += " {"  ;
733+                 }
734+                 for  (uint32_t  t2 = 0 ; t2 < 2 ; ++t2) {
735+                     if  (t2 == 0 ) {
736+                         data += " {"  ;
737+                         len += " {"  ;
738+                     }
739+                     for  (uint32_t  rte = 0 ; rte < 2 ; ++rte) {
740+                         if  (rte == 0 ) {
741+                             data += " {"  ;
742+                             len += " {"  ;
743+                         }
744+                         data += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0 ) ? " _rte"   : " "  );
745+                         len  += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0 ) ? " _rte"   : " "  );
746+                         data += " _data,"  ;
747+                         len  += " _len,"  ;
748+                         if  (rte == 1 ) {
749+                             data += " }, "  ;
750+                             len += " }, "  ;
751+                         }
752+                     }
753+                     if  (t2 == 1 ) {
754+                         data += " }, "  ;
755+                         len += " }, "  ;
756+                     }
757+                 }
758+                 if  (t1 == 1 ) {
759+                     data += " }, "  ;
760+                     len += " }, "  ;
761+                 }
762+             }
763+             if  (t0 == 1 ) {
764+                 data += " };\n "  ;
765+                 len += " };\n "  ;
766+             }
767+         }
768+         fprintf (src, data.c_str ());
769+         fprintf (src, len.c_str ());
717770    }
718771    fclose (hdr);
719772    fclose (src);
0 commit comments