
Commit fe92821

ggml : add bilinear upscale support (ggml/1185)

slaren authored and ggerganov committed
1 parent 459895c · commit fe92821

File tree

9 files changed: +122 -43 lines

ggml/include/ggml.h

Lines changed: 12 additions & 7 deletions
@@ -1717,24 +1717,29 @@ extern "C" {
             float                 p0,
             float                 p1);
 
-    // nearest interpolate
+    enum ggml_scale_mode {
+        GGML_SCALE_MODE_NEAREST  = 0,
+        GGML_SCALE_MODE_BILINEAR = 1,
+    };
+
+    // interpolate
     // multiplies ne0 and ne1 by scale factor
-    // used in stable-diffusion
     GGML_API struct ggml_tensor * ggml_upscale(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            int                   scale_factor);
+            int                   scale_factor,
+            enum ggml_scale_mode  mode);
 
-    // nearest interpolate
-    // nearest interpolate to specified dimensions
-    // used in tortoise.cpp
+    // interpolate
+    // interpolate scale to specified dimensions
     GGML_API struct ggml_tensor * ggml_upscale_ext(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             int                   ne0,
             int                   ne1,
             int                   ne2,
-            int                   ne3);
+            int                   ne3,
+            enum ggml_scale_mode  mode);
 
     // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
     GGML_API struct ggml_tensor * ggml_pad(
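
A minimal usage sketch of the changed API, for context. This is not part of the commit: the helper names are hypothetical, and a valid ggml_context and an F32 tensor are assumed.

// Sketch only: helper names are hypothetical, not part of ggml.
#include "ggml.h"

// 2x upscale of ne0/ne1 with the new bilinear filtering.
static struct ggml_tensor * upscale_2x_bilinear(struct ggml_context * ctx, struct ggml_tensor * a) {
    return ggml_upscale(ctx, a, 2, GGML_SCALE_MODE_BILINEAR);
}

// Resize ne0/ne1 to an explicit target; ggml_upscale_impl asserts that each target
// dim is >= the source dim, so this only scales up. Passing GGML_SCALE_MODE_NEAREST
// reproduces the behavior existing callers had before this change.
static struct ggml_tensor * resize_to(struct ggml_context * ctx, struct ggml_tensor * a, int w, int h) {
    return ggml_upscale_ext(ctx, a, w, h, (int) a->ne[2], (int) a->ne[3], GGML_SCALE_MODE_NEAREST);
}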

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 3 additions & 0 deletions
@@ -1824,6 +1824,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
             if (op->src[0]->ne[2] * op->ne[3] != op->src[0]->ne[3] * op->ne[2]) {
                 return false;
             }
+            if (op->op_params[0] != GGML_SCALE_MODE_NEAREST) {
+                return false;
+            }
             return true;
         }
         case GGML_OP_POOL_2D: {
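
The scale mode travels in op_params[0], so the CANN guard above and the backend hunks below all follow the same pattern: only nearest mode is claimed, and anything else is left for the scheduler to place on a backend that does support it (typically the CPU). A condensed sketch of that check, with a hypothetical helper name:

// Hypothetical helper (not in the commit) condensing the guard repeated by the backends.
#include "ggml.h"

static bool upscale_supported_nearest_only(const struct ggml_tensor * op) {
    const enum ggml_scale_mode mode = (enum ggml_scale_mode) op->op_params[0];
    return op->src[0]->type == GGML_TYPE_F32 && mode == GGML_SCALE_MODE_NEAREST;
}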

ggml/src/ggml-cpu/ops.cpp

Lines changed: 63 additions & 15 deletions
@@ -6351,24 +6351,72 @@ static void ggml_compute_forward_upscale_f32(
     const float sf2 = (float)ne2/src0->ne[2];
     const float sf3 = (float)ne3/src0->ne[3];
 
-    // TODO: optimize
-
-    for (int64_t i3 = 0; i3 < ne3; i3++) {
-        const int64_t i03 = i3 / sf3;
-        for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
-            const int64_t i02 = i2 / sf2;
-            for (int64_t i1 = 0; i1 < ne1; i1++) {
-                const int64_t i01 = i1 / sf1;
-                for (int64_t i0 = 0; i0 < ne0; i0++) {
-                    const int64_t i00 = i0 / sf0;
-
-                    const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                    float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
-
-                    *y = *x;
+    const ggml_scale_mode mode = (ggml_scale_mode) ggml_get_op_params_i32(dst, 0);
+
+    if (mode == GGML_SCALE_MODE_NEAREST) {
+        for (int64_t i3 = 0; i3 < ne3; i3++) {
+            const int64_t i03 = i3 / sf3;
+            for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
+                const int64_t i02 = i2 / sf2;
+                for (int64_t i1 = 0; i1 < ne1; i1++) {
+                    const int64_t i01 = i1 / sf1;
+                    for (int64_t i0 = 0; i0 < ne0; i0++) {
+                        const int64_t i00 = i0 / sf0;
+
+                        const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                        float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+
+                        *y = *x;
+                    }
+                }
+            }
+        }
+    } else if (mode == GGML_SCALE_MODE_BILINEAR) {
+        // setting a pixel offset of 0 would replicate the behavior of pytorch interpolate with align_corners=True
+        const float pixel_offset = 0.5f;
+
+        for (int64_t i3 = 0; i3 < ne3; i3++) {
+            const int64_t i03 = i3 / sf3;
+            for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
+                const int64_t i02 = i2 / sf2;
+                for (int64_t i1 = 0; i1 < ne1; i1++) {
+                    const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
+                    int64_t y0 = (int64_t)floorf(y);
+                    int64_t y1 = y0 + 1;
+
+                    y0 = std::max(int64_t(0), std::min(y0, ne01 - 1));
+                    y1 = std::max(int64_t(0), std::min(y1, ne01 - 1));
+
+                    float dy = y - (float)y0;
+                    dy = std::max(0.0f, std::min(dy, 1.0f));
+
+                    for (int64_t i0 = 0; i0 < ne0; i0++) {
+                        const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
+                        int64_t x0 = (int64_t)floorf(x);
+                        int64_t x1 = x0 + 1;
+
+                        x0 = std::max(int64_t(0), std::min(x0, ne00 - 1));
+                        x1 = std::max(int64_t(0), std::min(x1, ne00 - 1));
+
+                        float dx = x - (float)x0;
+                        dx = std::max(0.0f, std::min(dx, 1.0f));
+
+                        // fetch the four surrounding pixel values and interpolate
+                        const float a = *(const float *)((const char *)src0->data + x0*nb00 + y0*nb01 + i02*nb02 + i03*nb03);
+                        const float b = *(const float *)((const char *)src0->data + x1*nb00 + y0*nb01 + i02*nb02 + i03*nb03);
+                        const float c = *(const float *)((const char *)src0->data + x0*nb00 + y1*nb01 + i02*nb02 + i03*nb03);
+                        const float d = *(const float *)((const char *)src0->data + x1*nb00 + y1*nb01 + i02*nb02 + i03*nb03);
+
+                        const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;
+
+                        float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+                        *y_dst = val;
+                    }
                 }
             }
         }
+    } else {
+        GGML_ABORT("unsupported upscale mode");
     }
 }
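
To make the new branch easier to verify by hand, here is a self-contained sketch of the same sampling rule for a single row-major W x H float image. It mirrors the kernel above but is not part of the commit. The 0.5 pixel offset gives half-pixel centers, matching the default align_corners=False convention noted in the comment; with sf0 = sf1 = 2, output pixel i0 = 1 maps to x = (1 + 0.5)/2 - 0.5 = 0.25, so it blends source columns 0 and 1 with weights 0.75 and 0.25.

// Standalone sketch of the bilinear rule used above (half-pixel centers, clamped edges).
#include <algorithm>
#include <cmath>
#include <cstdint>

static float bilinear_sample(const float * src, int64_t w, int64_t h,
                             float sf0, float sf1, int64_t i0, int64_t i1) {
    const float pixel_offset = 0.5f;                      // same offset as the CPU kernel

    const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
    const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;

    int64_t x0 = (int64_t)std::floor(x), x1 = x0 + 1;
    int64_t y0 = (int64_t)std::floor(y), y1 = y0 + 1;

    x0 = std::clamp<int64_t>(x0, 0, w - 1);  x1 = std::clamp<int64_t>(x1, 0, w - 1);
    y0 = std::clamp<int64_t>(y0, 0, h - 1);  y1 = std::clamp<int64_t>(y1, 0, h - 1);

    const float dx = std::clamp(x - (float)x0, 0.0f, 1.0f);
    const float dy = std::clamp(y - (float)y0, 0.0f, 1.0f);

    const float a = src[y0*w + x0], b = src[y0*w + x1];   // top-left, top-right
    const float c = src[y1*w + x0], d = src[y1*w + x1];   // bottom-left, bottom-right

    return a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;
}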

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 1 addition & 0 deletions
@@ -3216,6 +3216,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_GROUP_NORM:
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_UPSCALE:
+            return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
         case GGML_OP_PAD:
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 2 additions & 1 deletion
@@ -1334,8 +1334,9 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
             return op->src[0]->type == GGML_TYPE_F16;
         case GGML_OP_POOL_1D:
             return false;
-        case GGML_OP_POOL_2D:
         case GGML_OP_UPSCALE:
+            return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
+        case GGML_OP_POOL_2D:
         case GGML_OP_PAD:
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_TIMESTEP_EMBEDDING:

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 2 additions & 1 deletion
@@ -4055,12 +4055,13 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_IM2COL:
             // TODO: add support for the new F32 operations
             return op->src[0]->type == GGML_TYPE_F16;
+        case GGML_OP_UPSCALE:
+            return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
-        case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_TIMESTEP_EMBEDDING:

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 4 additions & 3 deletions
@@ -5749,7 +5749,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         }
         return nullptr;
     case GGML_OP_UPSCALE:
-        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && dst->op_params[0] == GGML_SCALE_MODE_NEAREST) {
            return ctx->device->pipeline_upscale_f32;
        }
        return nullptr;
@@ -9404,9 +9404,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
        case GGML_OP_COS:
        case GGML_OP_CLAMP:
            return op->src[0]->type == GGML_TYPE_F32;
+        case GGML_OP_UPSCALE:
+            return op->op_params[0] == GGML_SCALE_MODE_NEAREST;
        case GGML_OP_ACC:
        case GGML_OP_CONCAT:
-        case GGML_OP_UPSCALE:
        case GGML_OP_SCALE:
        case GGML_OP_PAD:
        case GGML_OP_DIAG_MASK_INF:
@@ -9774,7 +9775,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     } else if (tensor->op == GGML_OP_CONCAT) {
         tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params);
     } else if (tensor->op == GGML_OP_UPSCALE) {
-        tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
+        tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]);
     } else if (tensor->op == GGML_OP_SCALE) {
         const float * params = (const float *)tensor->op_params;
         tensor_clone = ggml_scale(ggml_ctx, src_clone[0], params[0]);

ggml/src/ggml.c

Lines changed: 10 additions & 5 deletions
@@ -4174,14 +4174,17 @@ static struct ggml_tensor * ggml_upscale_impl(
         int                   ne0,
         int                   ne1,
         int                   ne2,
-        int                   ne3) {
+        int                   ne3,
+        enum ggml_scale_mode  mode) {
     GGML_ASSERT(a->ne[0] <= ne0);
     GGML_ASSERT(a->ne[1] <= ne1);
     GGML_ASSERT(a->ne[2] <= ne2);
     GGML_ASSERT(a->ne[3] <= ne3);
 
     struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
 
+    ggml_set_op_params_i32(result, 0, mode);
+
     result->op     = GGML_OP_UPSCALE;
     result->src[0] = a;
 
@@ -4191,8 +4194,9 @@
 struct ggml_tensor * ggml_upscale(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        int                   scale_factor) {
-    return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3]);
+        int                   scale_factor,
+        enum ggml_scale_mode  mode) {
+    return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3], mode);
 }
 
 struct ggml_tensor * ggml_upscale_ext(
@@ -4201,8 +4205,9 @@ struct ggml_tensor * ggml_upscale_ext(
         int                   ne0,
         int                   ne1,
         int                   ne2,
-        int                   ne3) {
-    return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3);
+        int                   ne3,
+        enum ggml_scale_mode  mode) {
+    return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
 }
 
 // ggml_pad
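
The mode is carried through the generic int32 op parameters: ggml_upscale_impl writes it at index 0, and both the CPU kernel and the backend supports_op checks read it back from there. Side by side, quoting the two lines from the hunks above (fragments, not a standalone function):

// write side, at graph construction time (ggml_upscale_impl):
ggml_set_op_params_i32(result, 0, mode);

// read side, in the CPU kernel (backends read op->op_params[0] directly):
const ggml_scale_mode mode = (ggml_scale_mode) ggml_get_op_params_i32(dst, 0);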

tests/test-backend-ops.cpp

Lines changed: 25 additions & 11 deletions
@@ -271,6 +271,14 @@ static std::string var_to_str(ggml_op_pool pool) {
     }
 }
 
+static std::string var_to_str(ggml_scale_mode mode) {
+    switch (mode) {
+        case GGML_SCALE_MODE_NEAREST:  return "nearest";
+        case GGML_SCALE_MODE_BILINEAR: return "bilinear";
+        default:                       return std::to_string(mode);
+    }
+}
+
 #define VAR_TO_STR(x) (#x "=" + var_to_str(x))
 
 #define VARS_TO_STR1(a) VAR_TO_STR(a)
@@ -2948,15 +2956,16 @@ struct test_upscale : public test_case {
     const std::array<int64_t, 4> ne;
     const int32_t scale_factor;
     const bool transpose;
+    const ggml_scale_mode mode;
 
     std::string vars() override {
-        return VARS_TO_STR4(type, ne, scale_factor, transpose);
+        return VARS_TO_STR5(type, ne, scale_factor, mode, transpose);
     }
 
     test_upscale(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {512, 512, 3, 1},
-            int32_t scale_factor = 2, bool transpose = false)
-        : type(type), ne(ne), scale_factor(scale_factor), transpose(transpose) {}
+            int32_t scale_factor = 2, ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST, bool transpose = false)
+        : type(type), ne(ne), scale_factor(scale_factor), mode(mode), transpose(transpose) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
@@ -2967,7 +2976,7 @@ struct test_upscale : public test_case {
             ggml_set_name(a, "a_transposed");
         }
 
-        ggml_tensor * out = ggml_upscale(ctx, a, scale_factor);
+        ggml_tensor * out = ggml_upscale(ctx, a, scale_factor, mode);
         ggml_set_name(out, "out");
 
         return out;
@@ -2979,21 +2988,23 @@ struct test_upscale_ext : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
     const std::array<int64_t, 4> ne_tgt;
+    const ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST;
 
     std::string vars() override {
-        return VARS_TO_STR3(type, ne, ne_tgt);
+        return VARS_TO_STR4(type, ne, ne_tgt, mode);
     }
 
     test_upscale_ext(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {2, 5, 7, 11},
-            std::array<int64_t, 4> ne_tgt = {5, 7, 11, 13})
-        : type(type), ne(ne), ne_tgt(ne_tgt) {}
+            std::array<int64_t, 4> ne_tgt = {5, 7, 11, 13},
+            ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST)
+        : type(type), ne(ne), ne_tgt(ne_tgt), mode(mode) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_name(a, "a");
 
-        ggml_tensor * out = ggml_upscale_ext(ctx, a, ne_tgt[0], ne_tgt[1], ne_tgt[2], ne_tgt[3]);
+        ggml_tensor * out = ggml_upscale_ext(ctx, a, ne_tgt[0], ne_tgt[1], ne_tgt[2], ne_tgt[3], mode);
         ggml_set_name(out, "out");
 
         return out;
@@ -4399,12 +4410,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
     }
 
+    for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) {
+        test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode));
+        test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true));
+        test_cases.emplace_back(new test_upscale_ext(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, mode));
+    }
+
     test_cases.emplace_back(new test_sum());
     test_cases.emplace_back(new test_sum_rows());
     test_cases.emplace_back(new test_mean());
-    test_cases.emplace_back(new test_upscale());
-    test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, { 512, 512, 3, 1 }, 2, true));
-    test_cases.emplace_back(new test_upscale_ext());
     test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1}));
     test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
     test_cases.emplace_back(new test_acc());
