@@ -3707,6 +3707,7 @@ struct test_im2col : public test_case {
37073707struct test_conv_2d : public test_case {
37083708 const std::array<int64_t , 4 > ne_input;
37093709 const std::array<int64_t , 4 > ne_kernel;
3710+ const ggml_type type_kernel;
37103711 const int stride0;
37113712 const int stride1;
37123713 const int padding0;
@@ -3724,7 +3725,7 @@ struct test_conv_2d : public test_case {
37243725 // IM2COL -> MUL_MM graph will be built.
37253726
37263727 std::string vars () override {
3727- return VARS_TO_STR9 (ne_input, ne_kernel, stride0, stride1, padding0, padding1, dilation0, dilation1, cwhn);
3728+ return VARS_TO_STR10 (ne_input, ne_kernel, type_kernel , stride0, stride1, padding0, padding1, dilation0, dilation1, cwhn);
37283729 }
37293730
37303731 uint64_t op_flops (ggml_tensor * t) override {
@@ -3755,10 +3756,11 @@ struct test_conv_2d : public test_case {
37553756 }
37563757
37573758 test_conv_2d (std::array<int64_t , 4 > ne_input = { 64 , 64 , 16 , 1 },
3758- std::array<int64_t , 4 > ne_kernel = { 3 , 3 , 1 , 16 }, int stride0 = 1 , int stride1 = 1 , int padding0 = 0 ,
3759- int padding1 = 0 , int dilation0 = 1 , int dilation1 = 1 , bool cwhn = false ) :
3759+ std::array<int64_t , 4 > ne_kernel = { 3 , 3 , 1 , 16 }, ggml_type type_kernel = GGML_TYPE_F32 , int stride0 = 1 ,
3760+ int stride1 = 1 , int padding0 = 0 , int padding1 = 0 , int dilation0 = 1 , int dilation1 = 1 , bool cwhn = false ) :
37603761 ne_input (ne_input),
37613762 ne_kernel (ne_kernel),
3763+ type_kernel (type_kernel),
37623764 stride0 (stride0),
37633765 stride1 (stride1),
37643766 padding0 (padding0),
@@ -3771,7 +3773,7 @@ struct test_conv_2d : public test_case {
37713773 ggml_tensor * input = ggml_new_tensor (ctx, GGML_TYPE_F32, 4 , ne_input.data ());
37723774 ggml_set_name (input, " input" );
37733775
3774- ggml_tensor * kernel = ggml_new_tensor (ctx, GGML_TYPE_F32 , 4 , ne_kernel.data ());
3776+ ggml_tensor * kernel = ggml_new_tensor (ctx, type_kernel , 4 , ne_kernel.data ());
37753777 ggml_set_name (kernel, " kernel" );
37763778
37773779 if (cwhn) {
@@ -5141,7 +5143,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
51415143 for (auto act_case : cases) {
51425144 test_cases.emplace_back (new test_conv_2d (
51435145 { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
5144- { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] }, 1 , 1 , 0 , 0 , 1 , 1 , false ));
5146+ { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] },
5147+ GGML_TYPE_F32, 1 , 1 , 0 , 0 , 1 , 1 , false ));
5148+ test_cases.emplace_back (new test_conv_2d (
5149+ { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
5150+ { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] },
5151+ GGML_TYPE_F16, 1 , 1 , 0 , 0 , 1 , 1 , false ));
51455152 }
51465153#endif
51475154
@@ -5168,7 +5175,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
51685175 if (calc_conv_output_size (W, KW, s0, p0, d0) > 0 &&
51695176 calc_conv_output_size (H, KH, s1, p1, d1) > 0 ) {
51705177 test_cases.emplace_back (new test_conv_2d (
5171- { W, H, Cin, 2 }, { KW, KH, Cin, Cout }, s0, s1, p0, p1, d0, d1, false ));
5178+ { W, H, Cin, 2 }, { KW, KH, Cin, Cout }, GGML_TYPE_F32, s0, s1, p0, p1, d0, d1, false ));
51725179 }
51735180 }
51745181 }
@@ -5817,7 +5824,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
58175824 // Direct CONV_2D
58185825 test_cases.emplace_back (new test_conv_2d (
58195826 { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
5820- { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] }, 1 , 1 , 0 , 0 , 1 , 1 , false ));
5827+ { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] },
5828+ GGML_TYPE_F32, 1 , 1 , 0 , 0 , 1 , 1 , false ));
5829+ test_cases.emplace_back (new test_conv_2d (
5830+ { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
5831+ { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] },
5832+ GGML_TYPE_F16, 1 , 1 , 0 , 0 , 1 , 1 , false ));
58215833 }
58225834
58235835 test_cases.emplace_back (new test_bin_bcast (ggml_add, GGML_TYPE_F32, {4096 , 1 , 1 , 1 }, {1 , 1 , 1 , 1 }));
0 commit comments