@@ -4116,6 +4116,94 @@ struct test_conv_2d : public test_case {
     }
 };
 
+// CONV_2D_IMPLICIT
+struct test_conv_2d_implicit : public test_case {
+    const std::array<int64_t, 4> ne_input;
+    const std::array<int64_t, 4> ne_kernel;
+    const ggml_type type_kernel;
+    const int stride0;
+    const int stride1;
+    const int padding0;
+    const int padding1;
+    const int dilation0;
+    const int dilation1;
+    // Whether the inputs are contiguous in the channel dim or the width dim
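+    // (ggml's default ne order is {W, H, C, N}, i.e. width-contiguous / WHCN)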
+    const bool cwhn;
+
+    std::string vars() override {
+        return VARS_TO_STR10(ne_input, ne_kernel, type_kernel, stride0, stride1, padding0, padding1, dilation0, dilation1, cwhn);
+    }
+
+    double max_nmse_err() override {
+        return 5e-4;
+    }
+
+    uint64_t op_flops(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        // Just counting matmul costs:
+        // KxCRS @ CRSxNPQ = KxNPQ --> KxNPQx(CRS+CRS-1) flops
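+        // (each of the K*NPQ output elements needs CRS multiplies and CRS-1 additions)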
+
+        // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d)
+        auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
+            return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
+        };
+
+        int64_t W = ne_input[0];
+        int64_t H = ne_input[1];
+        int64_t KW = ne_kernel[0];
+        int64_t KH = ne_kernel[1];
+        int64_t Cin = ne_kernel[2];
+        int64_t Cout = ne_kernel[3];
+        int64_t N = ne_input[3];
+        int64_t OH = calc_conv_output_size(H, KH, stride1, padding1, dilation1);
+        int64_t OW = calc_conv_output_size(W, KW, stride0, padding0, dilation0);
+
+        int64_t K = Cout;
+        int64_t CRS = Cin * KH * KW;
+        int64_t NPQ = N * OH * OW;
+
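+        // implicit GEMM view: the output matrix is K x NPQ with reduction length CRS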
+        return K * NPQ * (2 * CRS - 1);
+    }
+
+    test_conv_2d_implicit(std::array<int64_t, 4> ne_input = {64, 64, 16, 1},
+        std::array<int64_t, 4> ne_kernel = {3, 3, 1, 16}, ggml_type type_kernel = GGML_TYPE_F32, int stride0 = 1,
+        int stride1 = 1, int padding0 = 0, int padding1 = 0, int dilation0 = 1, int dilation1 = 1, bool cwhn = false) :
+        ne_input(ne_input),
+        ne_kernel(ne_kernel),
+        type_kernel(type_kernel),
+        stride0(stride0),
+        stride1(stride1),
+        padding0(padding0),
+        padding1(padding1),
+        dilation0(dilation0),
+        dilation1(dilation1),
+        cwhn(cwhn) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
+        ggml_set_name(input, "input");
+
+        ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());
+        ggml_set_name(kernel, "kernel");
+
+        if (cwhn) {
+            // change memory layout to channel-most-contiguous (CWHN),
+            // then permute it back so NE matches the original input
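+            // (each ggml_permute argument is the destination index of the corresponding source dim)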
+            input = ggml_cont(ctx, ggml_permute(ctx, input, 1, 2, 0, 3));
+            input = ggml_permute(ctx, input, 2, 0, 1, 3);
+            kernel = ggml_cont(ctx, ggml_permute(ctx, kernel, 2, 3, 1, 0));
+            kernel = ggml_permute(ctx, kernel, 3, 2, 0, 1);
+        }
+
+        ggml_tensor * out =
+            ggml_conv_2d_implicitgemm(ctx, kernel, input, stride0, stride1, padding0, padding1, dilation0, dilation1);
+        ggml_set_name(out, "out");
+        return out;
+    }
+};
+
 // GGML_OP_CONV_2D_DW
 struct test_conv_2d_dw : public test_case {
     const std::array<int64_t, 4> ne_input;
@@ -6454,6 +6542,17 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
         }
     }
 
+    for (auto kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+        for (auto act_case : cases) {
+            // Implicit GEMM CONV_2D
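+            // square activations and kernels, stride 1, no padding, no dilation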
+            test_cases.emplace_back(new test_conv_2d_implicit(
+                { act_case[iwh_idx], act_case[iwh_idx], act_case[Cin_idx], act_case[B_idx] },
+                { act_case[kwh_idx], act_case[kwh_idx], act_case[Cin_idx], act_case[Cout_idx] },
+                kernel_type, 1, 1, 0, 0, 1, 1, false));
+        }
+    }
+
     test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1}));
     test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));
 