@@ -271,6 +271,14 @@ static std::string var_to_str(ggml_op_pool pool) {
271271 }
272272}
273273
274+ static std::string var_to_str (ggml_scale_mode mode) {
275+ switch (mode) {
276+ case GGML_SCALE_MODE_NEAREST: return " nearest" ;
277+ case GGML_SCALE_MODE_BILINEAR: return " bilinear" ;
278+ default : return std::to_string (mode);
279+ }
280+ }
281+
274282#define VAR_TO_STR (x ) (#x " =" + var_to_str(x))
275283
276284#define VARS_TO_STR1 (a ) VAR_TO_STR(a)
@@ -2948,15 +2956,16 @@ struct test_upscale : public test_case {
29482956 const std::array<int64_t , 4 > ne;
29492957 const int32_t scale_factor;
29502958 const bool transpose;
2959+ const ggml_scale_mode mode;
29512960
29522961 std::string vars () override {
2953- return VARS_TO_STR4 (type, ne, scale_factor, transpose);
2962+ return VARS_TO_STR5 (type, ne, scale_factor, mode , transpose);
29542963 }
29552964
29562965 test_upscale (ggml_type type = GGML_TYPE_F32,
29572966 std::array<int64_t , 4 > ne = {512 , 512 , 3 , 1 },
2958- int32_t scale_factor = 2 , bool transpose = false )
2959- : type(type), ne(ne), scale_factor(scale_factor), transpose(transpose) {}
2967+ int32_t scale_factor = 2 , ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST, bool transpose = false )
2968+ : type(type), ne(ne), scale_factor(scale_factor), mode(mode), transpose(transpose) {}
29602969
29612970 ggml_tensor * build_graph (ggml_context * ctx) override {
29622971 ggml_tensor * a = ggml_new_tensor (ctx, type, 4 , ne.data ());
@@ -2967,7 +2976,7 @@ struct test_upscale : public test_case {
29672976 ggml_set_name (a, " a_transposed" );
29682977 }
29692978
2970- ggml_tensor * out = ggml_upscale (ctx, a, scale_factor);
2979+ ggml_tensor * out = ggml_upscale (ctx, a, scale_factor, mode );
29712980 ggml_set_name (out, " out" );
29722981
29732982 return out;
@@ -2979,21 +2988,23 @@ struct test_upscale_ext : public test_case {
29792988 const ggml_type type;
29802989 const std::array<int64_t , 4 > ne;
29812990 const std::array<int64_t , 4 > ne_tgt;
2991+ const ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST;
29822992
29832993 std::string vars () override {
2984- return VARS_TO_STR3 (type, ne, ne_tgt);
2994+ return VARS_TO_STR4 (type, ne, ne_tgt, mode );
29852995 }
29862996
29872997 test_upscale_ext (ggml_type type = GGML_TYPE_F32,
29882998 std::array<int64_t , 4 > ne = {2 , 5 , 7 , 11 },
2989- std::array<int64_t , 4 > ne_tgt = {5 , 7 , 11 , 13 })
2990- : type(type), ne(ne), ne_tgt(ne_tgt) {}
2999+ std::array<int64_t , 4 > ne_tgt = {5 , 7 , 11 , 13 },
3000+ ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST)
3001+ : type(type), ne(ne), ne_tgt(ne_tgt), mode(mode) {}
29913002
29923003 ggml_tensor * build_graph (ggml_context * ctx) override {
29933004 ggml_tensor * a = ggml_new_tensor (ctx, type, 4 , ne.data ());
29943005 ggml_set_name (a, " a" );
29953006
2996- ggml_tensor * out = ggml_upscale_ext (ctx, a, ne_tgt[0 ], ne_tgt[1 ],ne_tgt[2 ], ne_tgt[3 ]);
3007+ ggml_tensor * out = ggml_upscale_ext (ctx, a, ne_tgt[0 ], ne_tgt[1 ],ne_tgt[2 ], ne_tgt[3 ], mode );
29973008 ggml_set_name (out, " out" );
29983009
29993010 return out;
@@ -4399,12 +4410,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
43994410 test_cases.emplace_back (new test_argsort (GGML_TYPE_F32, {60 , 10 , 10 , 10 }, order)); // qwen
44004411 }
44014412
4413+ for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) {
4414+ test_cases.emplace_back (new test_upscale (GGML_TYPE_F32, {512 , 512 , 3 , 2 }, 2 , mode));
4415+ test_cases.emplace_back (new test_upscale (GGML_TYPE_F32, {512 , 512 , 3 , 2 }, 2 , mode, true ));
4416+ test_cases.emplace_back (new test_upscale_ext (GGML_TYPE_F32, {2 , 5 , 7 , 11 }, {5 , 7 , 11 , 13 }, mode));
4417+ }
4418+
44024419 test_cases.emplace_back (new test_sum ());
44034420 test_cases.emplace_back (new test_sum_rows ());
44044421 test_cases.emplace_back (new test_mean ());
4405- test_cases.emplace_back (new test_upscale ());
4406- test_cases.emplace_back (new test_upscale (GGML_TYPE_F32, { 512 , 512 , 3 , 1 }, 2 , true ));
4407- test_cases.emplace_back (new test_upscale_ext ());
44084422 test_cases.emplace_back (new test_group_norm (GGML_TYPE_F32, {64 , 64 , 320 , 1 }));
44094423 test_cases.emplace_back (new test_group_norm (GGML_TYPE_F32, {9 , 9 , 1280 , 1 }));
44104424 test_cases.emplace_back (new test_acc ());
0 commit comments