@@ -3296,28 +3296,28 @@ struct test_upscale : public test_case {
3296
3296
}
3297
3297
};
3298
3298
3299
- // GGML_OP_UPSCALE (ext )
3300
- struct test_upscale_ext : public test_case {
3299
+ // GGML_OP_UPSCALE (via ggml_interpolate )
3300
+ struct test_interpolate : public test_case {
3301
3301
const ggml_type type;
3302
3302
const std::array<int64_t , 4 > ne;
3303
3303
const std::array<int64_t , 4 > ne_tgt;
3304
- const ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST;
3304
+ const uint32_t mode = GGML_SCALE_MODE_NEAREST;
3305
3305
3306
3306
std::string vars () override {
3307
3307
return VARS_TO_STR4 (type, ne, ne_tgt, mode);
3308
3308
}
3309
3309
3310
- test_upscale_ext (ggml_type type = GGML_TYPE_F32,
3310
+ test_interpolate (ggml_type type = GGML_TYPE_F32,
3311
3311
std::array<int64_t , 4 > ne = {2 , 5 , 7 , 11 },
3312
3312
std::array<int64_t , 4 > ne_tgt = {5 , 7 , 11 , 13 },
3313
- ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST)
3313
+ uint32_t mode = GGML_SCALE_MODE_NEAREST)
3314
3314
: type(type), ne(ne), ne_tgt(ne_tgt), mode(mode) {}
3315
3315
3316
3316
ggml_tensor * build_graph (ggml_context * ctx) override {
3317
3317
ggml_tensor * a = ggml_new_tensor (ctx, type, 4 , ne.data ());
3318
3318
ggml_set_name (a, " a" );
3319
3319
3320
- ggml_tensor * out = ggml_upscale_ext (ctx, a, ne_tgt[0 ], ne_tgt[1 ],ne_tgt[2 ], ne_tgt[3 ], mode);
3320
+ ggml_tensor * out = ggml_interpolate (ctx, a, ne_tgt[0 ], ne_tgt[1 ],ne_tgt[2 ], ne_tgt[3 ], mode);
3321
3321
ggml_set_name (out, " out" );
3322
3322
3323
3323
return out;
@@ -4799,8 +4799,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
4799
4799
for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) {
4800
4800
test_cases.emplace_back (new test_upscale (GGML_TYPE_F32, {512 , 512 , 3 , 2 }, 2 , mode));
4801
4801
test_cases.emplace_back (new test_upscale (GGML_TYPE_F32, {512 , 512 , 3 , 2 }, 2 , mode, true ));
4802
- test_cases.emplace_back (new test_upscale_ext (GGML_TYPE_F32, {2 , 5 , 7 , 11 }, {5 , 7 , 11 , 13 }, mode));
4802
+ test_cases.emplace_back (new test_interpolate (GGML_TYPE_F32, {2 , 5 , 7 , 11 }, {5 , 7 , 11 , 13 }, mode));
4803
+ test_cases.emplace_back (new test_interpolate (GGML_TYPE_F32, {5 , 7 , 11 , 13 }, {2 , 5 , 7 , 11 }, mode));
4803
4804
}
4805
+ test_cases.emplace_back (new test_interpolate (GGML_TYPE_F32, {2 , 5 , 7 , 11 }, {5 , 7 , 11 , 13 }, GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS));
4804
4806
4805
4807
test_cases.emplace_back (new test_sum ());
4806
4808
test_cases.emplace_back (new test_sum_rows ());
0 commit comments