@@ -3308,15 +3308,41 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
33083308 }
33093309 }
33103310
3311- test_cases.emplace_back (new test_im2col (GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));
3312- test_cases.emplace_back (new test_im2col (GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
3313- test_cases.emplace_back (new test_im2col (GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
3314- // test cases for 1D im2col
3311+ // im2col 1D
33153312 test_cases.emplace_back (new test_im2col (GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {3000 , 128 , 1 , 1 }, {3 , 128 , 1280 , 1 }, 1 , 0 , 1 , 0 , 1 , 0 , false ));
33163313 test_cases.emplace_back (new test_im2col (GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000 , 128 , 1 , 1 }, {3 , 128 , 1280 , 1 }, 1 , 0 , 1 , 0 , 1 , 0 , false ));
33173314 test_cases.emplace_back (new test_im2col (GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000 , 128 , 1 , 1 }, {3 , 128 , 1280 , 1 }, 1 , 0 , 1 , 0 , 1 , 0 , false ));
3315+ for (int s0 : {1 , 3 }) {
3316+ for (int p0 : {0 , 3 }) {
3317+ for (int d0 : {1 , 3 }) {
3318+ test_cases.emplace_back (new test_im2col (
3319+ GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20 , 2 , 2 , 1 }, {3 , 2 , 2 , 1 },
3320+ s0, 0 , p0, 0 , d0, 0 , false ));
3321+ }
3322+ }
3323+ }
3324+
3325+ // im2col 2D
3326+ test_cases.emplace_back (new test_im2col (GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));
3327+ test_cases.emplace_back (new test_im2col (GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
3328+ test_cases.emplace_back (new test_im2col (GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
3329+ for (int s0 : {1 , 3 }) {
3330+ for (int s1 : {1 , 3 }) {
3331+ for (int p0 : {0 , 3 }) {
3332+ for (int p1 : {0 , 3 }) {
3333+ for (int d0 : {1 , 3 }) {
3334+ for (int d1 : {1 , 3 }) {
3335+ test_cases.emplace_back (new test_im2col (
3336+ GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20 , 20 , 2 , 2 }, {3 , 3 , 2 , 2 },
3337+ s0, s1, p0, p1, d0, d1, true ));
3338+ }
3339+ }
3340+ }
3341+ }
3342+ }
3343+ }
33183344
3319- // test cases for 2D im2col
3345+ // extra tests for im2col 2D
33203346 test_cases.emplace_back (new test_im2col (GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12 , 12 , 1 , 32 }, {3 , 3 , 1 , 32 }, 1 , 1 , 1 , 1 , 1 , 1 , true ));
33213347 test_cases.emplace_back (new test_im2col (GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12 , 12 , 2 , 32 }, {3 , 3 , 2 , 32 }, 1 , 1 , 1 , 1 , 1 , 1 , true ));
33223348 test_cases.emplace_back (new test_im2col (GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12 , 12 , 1 , 1024 }, {3 , 3 , 1 , 1024 }, 1 , 1 , 1 , 1 , 1 , 1 , true ));
0 commit comments