diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 452c967b0a637..8fcc16df998be 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -507,17 +507,12 @@ extern "C" {
 
         GGML_OP_UNARY,
 
-        GGML_OP_MAP_UNARY,
-        GGML_OP_MAP_BINARY,
-
-        GGML_OP_MAP_CUSTOM1_F32,
-        GGML_OP_MAP_CUSTOM2_F32,
-        GGML_OP_MAP_CUSTOM3_F32,
-
         GGML_OP_MAP_CUSTOM1,
         GGML_OP_MAP_CUSTOM2,
         GGML_OP_MAP_CUSTOM3,
 
+        GGML_OP_CUSTOM,
+
         GGML_OP_CROSS_ENTROPY_LOSS,
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
         GGML_OP_OPT_STEP_ADAMW,
@@ -1722,24 +1717,29 @@ extern "C" {
             float                 p0,
             float                 p1);
 
-    // nearest interpolate
+    enum ggml_scale_mode {
+        GGML_SCALE_MODE_NEAREST  = 0,
+        GGML_SCALE_MODE_BILINEAR = 1,
+    };
+
+    // interpolate
     // multiplies ne0 and ne1 by scale factor
-    // used in stable-diffusion
     GGML_API struct ggml_tensor * ggml_upscale(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            int                   scale_factor);
+            int                   scale_factor,
+            enum ggml_scale_mode  mode);
 
-    // nearest interpolate
-    // nearest interpolate to specified dimensions
-    // used in tortoise.cpp
+    // interpolate
+    // interpolate scale to specified dimensions
    GGML_API struct ggml_tensor * ggml_upscale_ext(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int                   ne0,
            int                   ne1,
            int                   ne2,
-           int                   ne3);
+           int                   ne3,
+           enum ggml_scale_mode  mode);
 
     // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
     GGML_API struct ggml_tensor * ggml_pad(
@@ -1916,83 +1916,6 @@ extern "C" {
 
     // custom operators
 
-    typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
-    typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *);
-
-    typedef void (*ggml_custom1_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *);
-    typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
-    typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_f32(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            ggml_unary_op_f32_t   fun),
-        "use ggml_map_custom1 instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_inplace_f32(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            ggml_unary_op_f32_t   fun),
-        "use ggml_map_custom1_inplace instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_f32(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            ggml_binary_op_f32_t  fun),
-        "use ggml_map_custom2 instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_inplace_f32(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            ggml_binary_op_f32_t  fun),
-        "use ggml_map_custom2_inplace instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_f32(
-            struct ggml_context   * ctx,
-            struct ggml_tensor    * a,
-            ggml_custom1_op_f32_t   fun),
-        "use ggml_map_custom1 instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_inplace_f32(
-            struct ggml_context   * ctx,
-            struct ggml_tensor    * a,
-            ggml_custom1_op_f32_t   fun),
-        "use ggml_map_custom1_inplace instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_f32(
-            struct ggml_context   * ctx,
-            struct ggml_tensor    * a,
-            struct ggml_tensor    * b,
-            ggml_custom2_op_f32_t   fun),
-        "use ggml_map_custom2 instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_inplace_f32(
-            struct ggml_context   * ctx,
-            struct ggml_tensor    * a,
-            struct ggml_tensor    * b,
-            ggml_custom2_op_f32_t   fun),
-        "use ggml_map_custom2_inplace instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_f32(
-            struct ggml_context   * ctx,
-            struct ggml_tensor    * a,
-            struct ggml_tensor    * b,
-            struct ggml_tensor    * c,
-            ggml_custom3_op_f32_t   fun),
-        "use ggml_map_custom3 instead");
-
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_inplace_f32(
-            struct ggml_context   * ctx,
-            struct ggml_tensor    * a,
-            struct ggml_tensor    * b,
-            struct ggml_tensor    * c,
-            ggml_custom3_op_f32_t   fun),
-        "use ggml_map_custom3_inplace instead");
-
-    // custom operators v2
-
     typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
     typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata);
     typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata);
@@ -2048,6 +1971,30 @@ extern "C" {
             int                   n_tasks,
             void                * userdata);
 
+    typedef void (*ggml_custom_op_t)(struct ggml_tensor * dst , int ith, int nth, void * userdata);
+
+    GGML_API struct ggml_tensor * ggml_custom_4d(
+            struct ggml_context * ctx,
+            enum ggml_type        type,
+            int64_t               ne0,
+            int64_t               ne1,
+            int64_t               ne2,
+            int64_t               ne3,
+            struct ggml_tensor ** args,
+            int                   n_args,
+            ggml_custom_op_t      fun,
+            int                   n_tasks,
+            void                * userdata);
+
+    GGML_API struct ggml_tensor * ggml_custom_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor ** args,
+            int                   n_args,
+            ggml_custom_op_t      fun,
+            int                   n_tasks,
+            void                * userdata);
+
     // loss function
 
     GGML_API struct ggml_tensor * ggml_cross_entropy_loss(
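Taken together, the header changes above mean every existing `ggml_upscale`/`ggml_upscale_ext` call site has to pass a mode explicitly. A minimal sketch of the updated call sites, assuming an initialized `ggml_context * ctx` and an F32 tensor `cur` (both hypothetical names, not from the patch):

```c
// sketch only: `ctx` and `cur` are assumed to exist; the mode argument is the
// only new parameter relative to the old API
struct ggml_tensor * up_near = ggml_upscale(ctx, cur, 2, GGML_SCALE_MODE_NEAREST);

// the _ext variant scales to explicit target sizes; mode is the new last argument
struct ggml_tensor * up_bili = ggml_upscale_ext(ctx, cur,
        (int) cur->ne[0]*2, (int) cur->ne[1]*2, (int) cur->ne[2], (int) cur->ne[3],
        GGML_SCALE_MODE_BILINEAR);
```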
diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
index b513270c6e5ac..cec36b36e7e92 100644
--- a/ggml/src/ggml-cann/ggml-cann.cpp
+++ b/ggml/src/ggml-cann/ggml-cann.cpp
@@ -1824,6 +1824,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
             if (op->src[0]->ne[2] * op->ne[3] != op->src[0]->ne[3] * op->ne[2]) {
                 return false;
             }
+            if (op->op_params[0] != GGML_SCALE_MODE_NEAREST) {
+                return false;
+            }
             return true;
         }
         case GGML_OP_POOL_2D: {
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 34618c27aa475..50400328738ef 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -2027,41 +2027,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rwkv_wkv7(params, tensor);
             } break;
-        case GGML_OP_MAP_UNARY:
-            {
-                ggml_unary_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_unary(params, tensor, fun);
-            }
-            break;
-        case GGML_OP_MAP_BINARY:
-            {
-                ggml_binary_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_binary(params, tensor, fun);
-            }
-            break;
-        case GGML_OP_MAP_CUSTOM1_F32:
-            {
-                ggml_custom1_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_custom1_f32(params, tensor, fun);
-            }
-            break;
-        case GGML_OP_MAP_CUSTOM2_F32:
-            {
-                ggml_custom2_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_custom2_f32(params, tensor, fun);
-            }
-            break;
-        case GGML_OP_MAP_CUSTOM3_F32:
-            {
-                ggml_custom3_op_f32_t fun;
-                memcpy(&fun, tensor->op_params, sizeof(fun));
-                ggml_compute_forward_map_custom3_f32(params, tensor, fun);
-            }
-            break;
         case GGML_OP_MAP_CUSTOM1:
             {
                 ggml_compute_forward_map_custom1(params, tensor);
@@ -2077,6 +2042,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 ggml_compute_forward_map_custom3(params, tensor);
             }
             break;
+        case GGML_OP_CUSTOM:
+            {
+                ggml_compute_forward_custom(params, tensor);
+            }
+            break;
         case GGML_OP_CROSS_ENTROPY_LOSS:
             {
                 ggml_compute_forward_cross_entropy_loss(params, tensor);
@@ -2328,11 +2298,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_PART:
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
-        case GGML_OP_MAP_UNARY:
-        case GGML_OP_MAP_BINARY:
-        case GGML_OP_MAP_CUSTOM1_F32:
-        case GGML_OP_MAP_CUSTOM2_F32:
-        case GGML_OP_MAP_CUSTOM3_F32:
             {
                 n_tasks = 1;
             } break;
@@ -2366,6 +2331,16 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                     n_tasks = MIN(p.n_tasks, n_threads);
                 }
             } break;
+        case GGML_OP_CUSTOM:
+            {
+                struct ggml_custom_op_params p;
+                memcpy(&p, node->op_params, sizeof(p));
+                if (p.n_tasks == GGML_N_TASKS_MAX) {
+                    n_tasks = n_threads;
+                } else {
+                    n_tasks = MIN(p.n_tasks, n_threads);
+                }
+            } break;
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index f63656be54f5c..6050147be70ac 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -6351,24 +6351,72 @@ static void ggml_compute_forward_upscale_f32(
     const float sf2 = (float)ne2/src0->ne[2];
     const float sf3 = (float)ne3/src0->ne[3];
 
-    // TODO: optimize
-
-    for (int64_t i3 = 0; i3 < ne3; i3++) {
-        const int64_t i03 = i3 / sf3;
-        for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
-            const int64_t i02 = i2 / sf2;
-            for (int64_t i1 = 0; i1 < ne1; i1++) {
-                const int64_t i01 = i1 / sf1;
-                for (int64_t i0 = 0; i0 < ne0; i0++) {
-                    const int64_t i00 = i0 / sf0;
-
-                    const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-                    float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
-
-                    *y = *x;
-                }
-            }
-        }
-    }
+    const ggml_scale_mode mode = (ggml_scale_mode) ggml_get_op_params_i32(dst, 0);
+
+    if (mode == GGML_SCALE_MODE_NEAREST) {
+        for (int64_t i3 = 0; i3 < ne3; i3++) {
+            const int64_t i03 = i3 / sf3;
+            for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
+                const int64_t i02 = i2 / sf2;
+                for (int64_t i1 = 0; i1 < ne1; i1++) {
+                    const int64_t i01 = i1 / sf1;
+                    for (int64_t i0 = 0; i0 < ne0; i0++) {
+                        const int64_t i00 = i0 / sf0;
+
+                        const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                        float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+
+                        *y = *x;
+                    }
+                }
+            }
+        }
+    } else if (mode == GGML_SCALE_MODE_BILINEAR) {
+        // setting a pixel offset of 0 would replicate the behavior of pytorch interpolate with align_corners=True
+        const float pixel_offset = 0.5f;
+
+        for (int64_t i3 = 0; i3 < ne3; i3++) {
+            const int64_t i03 = i3 / sf3;
+            for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
+                const int64_t i02 = i2 / sf2;
+                for (int64_t i1 = 0; i1 < ne1; i1++) {
+                    const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
+                    int64_t y0 = (int64_t)floorf(y);
+                    int64_t y1 = y0 + 1;
+
+                    y0 = std::max(int64_t(0), std::min(y0, ne01 - 1));
+                    y1 = std::max(int64_t(0), std::min(y1, ne01 - 1));
+
+                    float dy = y - (float)y0;
+                    dy = std::max(0.0f, std::min(dy, 1.0f));
+
+                    for (int64_t i0 = 0; i0 < ne0; i0++) {
+                        const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
+                        int64_t x0 = (int64_t)floorf(x);
+                        int64_t x1 = x0 + 1;
+
+                        x0 = std::max(int64_t(0), std::min(x0, ne00 - 1));
+                        x1 = std::max(int64_t(0), std::min(x1, ne00 - 1));
+
+                        float dx = x - (float)x0;
+                        dx = std::max(0.0f, std::min(dx, 1.0f));
+
+                        // fetch the four surrounding pixel values and interpolate
+                        const float a = *(const float *)((const char *)src0->data + x0*nb00 + y0*nb01 + i02*nb02 + i03*nb03);
+                        const float b = *(const float *)((const char *)src0->data + x1*nb00 + y0*nb01 + i02*nb02 + i03*nb03);
+                        const float c = *(const float *)((const char *)src0->data + x0*nb00 + y1*nb01 + i02*nb02 + i03*nb03);
+                        const float d = *(const float *)((const char *)src0->data + x1*nb00 + y1*nb01 + i02*nb02 + i03*nb03);
+
+                        const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;
+
+                        float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
+                        *y_dst = val;
+                    }
+                }
+            }
+        }
+    } else {
+        GGML_ABORT("unsupported upscale mode");
+    }
 }
 
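To make the sampling convention of the bilinear branch concrete, here is a small standalone sketch (illustrative names, not part of the patch) of the per-axis coordinate math it performs. For example, at scale factor 2 over a 2-wide source axis, output index 1 maps to source coordinate 0.25 and blends neighbors 0 and 1 with weights 0.75/0.25:

```c
#include <math.h>
#include <stdint.h>

// Illustrative sketch of the per-axis source-coordinate computation used by
// the bilinear branch above: for output index i at scale factor sf over a
// source axis of length n_src, compute the two neighbor indices and the blend
// weight dx. pixel_offset = 0.5f samples at half-pixel centers; the patch
// comment notes that 0 would replicate pytorch interpolate with align_corners=True.
static void bilinear_axis(int64_t i, float sf, int64_t n_src,
                          int64_t * x0, int64_t * x1, float * dx) {
    const float pixel_offset = 0.5f;
    const float x = ((float) i + pixel_offset) / sf - pixel_offset;

    int64_t lo = (int64_t) floorf(x);
    int64_t hi = lo + 1;

    // clamp both neighbors into [0, n_src - 1], exactly like the kernel
    lo = lo < 0 ? 0 : (lo > n_src - 1 ? n_src - 1 : lo);
    hi = hi < 0 ? 0 : (hi > n_src - 1 ? n_src - 1 : hi);

    // the blend weight is taken relative to the clamped low neighbor and
    // clamped into [0, 1], so border pixels are replicated
    float d = x - (float) lo;
    d = d < 0.0f ? 0.0f : (d > 1.0f ? 1.0f : d);

    *x0 = lo;
    *x1 = hi;
    *dx = d;
    // e.g. i = 1, sf = 2, n_src = 2: x = 0.25, x0 = 0, x1 = 1, dx = 0.25
}
```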
@@ -8268,152 +8316,6 @@ void ggml_compute_forward_rwkv_wkv7(
     }
 }
 
-// ggml_compute_forward_map_unary
-
-static void ggml_compute_forward_map_unary_f32(
-        const ggml_compute_params * params,
-        ggml_tensor * dst,
-        const ggml_unary_op_f32_t fun) {
-
-    const ggml_tensor * src0 = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        fun(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])));
-    }
-}
-
-void ggml_compute_forward_map_unary(
-        const ggml_compute_params * params,
-        ggml_tensor * dst,
-        const ggml_unary_op_f32_t fun) {
-
-    const ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_map_unary_f32(params, dst, fun);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_map_binary
-
-static void ggml_compute_forward_map_binary_f32(
-        const ggml_compute_params * params,
-        ggml_tensor * dst,
-        const ggml_binary_op_f32_t fun) {
-
-    const ggml_tensor * src0 = dst->src[0];
-    const ggml_tensor * src1 = dst->src[1];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    assert(ggml_is_contiguous_1(src0));
-    assert(ggml_is_contiguous_1(src1));
-    assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
-
-    const int n  = ggml_nrows(src0);
-    const int nc = src0->ne[0];
-
-    for (int i = 0; i < n; i++) {
-        fun(nc,
-                (float *) ((char *) dst->data  + i*( dst->nb[1])),
-                (float *) ((char *) src0->data + i*(src0->nb[1])),
-                (float *) ((char *) src1->data + i*(src1->nb[1])));
-    }
-}
-
-void ggml_compute_forward_map_binary(
-        const ggml_compute_params * params,
-        ggml_tensor * dst,
-        const ggml_binary_op_f32_t fun) {
-
-    const ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_map_binary_f32(params, dst, fun);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
-    }
-}
-
-// ggml_compute_forward_map_custom1
-
-void ggml_compute_forward_map_custom1_f32(
-        const ggml_compute_params * params,
-        ggml_tensor * dst,
-        const ggml_custom1_op_f32_t fun) {
-
-    const ggml_tensor * a = dst->src[0];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    fun(dst, a);
-}
-
-// ggml_compute_forward_map_custom2
-
-void ggml_compute_forward_map_custom2_f32(
-        const ggml_compute_params * params,
-        ggml_tensor * dst,
-        const ggml_custom2_op_f32_t fun) {
-
-    const ggml_tensor * a = dst->src[0];
-    const ggml_tensor * b = dst->src[1];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    fun(dst, a, b);
-}
-
-// ggml_compute_forward_map_custom3
-
-void ggml_compute_forward_map_custom3_f32(
-        const ggml_compute_params * params,
-        ggml_tensor * dst,
-        const ggml_custom3_op_f32_t fun) {
-
-    const ggml_tensor * a = dst->src[0];
-    const ggml_tensor * b = dst->src[1];
-    const ggml_tensor * c = dst->src[1];
-
-    if (params->ith != 0) {
-        return;
-    }
-
-    fun(dst, a, b, c);
-}
-
 // ggml_compute_forward_map_custom1
 
 void ggml_compute_forward_map_custom1(
@@ -8459,6 +8361,18 @@ void ggml_compute_forward_map_custom3(
     p.fun(dst, a, b, c, params->ith, params->nth, p.userdata);
 }
 
+// ggml_compute_forward_custom
+
+void ggml_compute_forward_custom(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    struct ggml_custom_op_params p;
+    memcpy(&p, dst->op_params, sizeof(p));
+
+    p.fun(dst, params->ith, params->nth, p.userdata);
+}
+
 // ggml_compute_forward_cross_entropy_loss
 
 static void ggml_compute_forward_cross_entropy_loss_f32(
diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h
index d43fbc1fc472a..410a372047a01 100644
--- a/ggml/src/ggml-cpu/ops.h
+++ b/ggml/src/ggml-cpu/ops.h
@@ -96,29 +96,10 @@ void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params
 void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-void ggml_compute_forward_map_unary(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst,
-        const ggml_unary_op_f32_t fun);
-void ggml_compute_forward_map_binary(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst,
-        const ggml_binary_op_f32_t fun);
-void ggml_compute_forward_map_custom1_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst,
-        const ggml_custom1_op_f32_t fun);
-void ggml_compute_forward_map_custom2_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst,
-        const ggml_custom2_op_f32_t fun);
-void ggml_compute_forward_map_custom3_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst,
-        const ggml_custom3_op_f32_t fun);
 void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_custom(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 633456a92d0de..fafe9633e2027 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3216,6 +3216,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_GROUP_NORM:
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_UPSCALE:
+            return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
         case GGML_OP_PAD:
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index caa6b9dba3f06..a19cfb14e0f9f 100644
--- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h
@@ -16,6 +16,14 @@
 #include <arm_sve.h>
 #endif // __ARM_FEATURE_SVE
 
+#if defined(__ARM_NEON) && !defined(__CUDACC__) && !defined(__MUSACC__)
+// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
+//
+//   $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
+//
+#include <arm_neon.h>
+#endif
+
 #if defined(__F16C__)
 #include <immintrin.h>
 #endif
@@ -140,8 +148,14 @@ struct ggml_map_custom2_op_params {
 
 struct ggml_map_custom3_op_params {
     ggml_custom3_op_t fun;
-    int n_tasks;
-    void * userdata;
+    int               n_tasks;
+    void            * userdata;
+};
+
+struct ggml_custom_op_params {
+    ggml_custom_op_t fun;
+    int              n_tasks;
+    void           * userdata;
 };
 
 // bitset
@@ -311,13 +325,6 @@ GGML_API void ggml_aligned_free(void * ptr, size_t size);
 // for     MUSA compilers        , we use uint16_t: ref https://github.com/ggml-org/llama.cpp/pull/11843
 //
 #if defined(__ARM_NEON) && !(defined(__CUDACC__) && __CUDACC_VER_MAJOR__ <= 11) && !defined(__MUSACC__)
-
-    // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
-    //
-    //   $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
-    //
-    #include <arm_neon.h>
-
     #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
     #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index f226826020a5a..9f1c6c6ccc09f 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -1334,8 +1334,9 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
             return op->src[0]->type == GGML_TYPE_F16;
         case GGML_OP_POOL_1D:
             return false;
-        case GGML_OP_POOL_2D:
         case GGML_OP_UPSCALE:
+            return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
+        case GGML_OP_POOL_2D:
         case GGML_OP_PAD:
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_TIMESTEP_EMBEDDING:
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index 89715eaea0753..e6f1603d84e07 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -4055,12 +4055,13 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_IM2COL:
             // TODO: add support for the new F32 operations
             return op->src[0]->type == GGML_TYPE_F16;
+        case GGML_OP_UPSCALE:
+            return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
-        case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_TIMESTEP_EMBEDDING:
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index e69d00ad54978..783a0ff86c1c1 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -5749,7 +5749,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         }
         return nullptr;
     case GGML_OP_UPSCALE:
-        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && dst->op_params[0] == GGML_SCALE_MODE_NEAREST) {
             return ctx->device->pipeline_upscale_f32;
         }
         return nullptr;
@@ -9404,9 +9404,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_COS:
         case GGML_OP_CLAMP:
             return op->src[0]->type == GGML_TYPE_F32;
+        case GGML_OP_UPSCALE:
+            return op->op_params[0] == GGML_SCALE_MODE_NEAREST;
         case GGML_OP_ACC:
         case GGML_OP_CONCAT:
-        case GGML_OP_UPSCALE:
         case GGML_OP_SCALE:
         case GGML_OP_PAD:
         case GGML_OP_DIAG_MASK_INF:
@@ -9774,7 +9775,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     } else if (tensor->op == GGML_OP_CONCAT) {
         tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params);
     } else if (tensor->op == GGML_OP_UPSCALE) {
-        tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
+        tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]);
     } else if (tensor->op == GGML_OP_SCALE) {
         const float * params = (const float *)tensor->op_params;
         tensor_clone = ggml_scale(ggml_ctx, src_clone[0], params[0]);
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 3e274d6ae3961..950772c75cb32 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -982,23 +982,18 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
 
     "UNARY",
 
-    "MAP_UNARY",
-    "MAP_BINARY",
-
-    "MAP_CUSTOM1_F32",
-    "MAP_CUSTOM2_F32",
-    "MAP_CUSTOM3_F32",
-
     "MAP_CUSTOM1",
     "MAP_CUSTOM2",
     "MAP_CUSTOM3",
 
+    "CUSTOM",
+
     "CROSS_ENTROPY_LOSS",
     "CROSS_ENTROPY_LOSS_BACK",
     "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
+static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1081,23 +1076,18 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
 
     "unary(x)",
 
-    "f(x)",
-    "f(x,y)",
-
-    "custom_f32(x)",
-    "custom_f32(x,y)",
-    "custom_f32(x,y,z)",
+    "map_custom(x)",
+    "map_custom(x,y)",
+    "map_custom(x,y,z)",
 
     "custom(x)",
-    "custom(x,y)",
-    "custom(x,y,z)",
 
     "cross_entropy_loss(x,y)",
     "cross_entropy_loss_back(x,y)",
     "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
+static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -4184,7 +4174,8 @@ static struct ggml_tensor * ggml_upscale_impl(
         int                   ne0,
         int                   ne1,
         int                   ne2,
-        int                   ne3) {
+        int                   ne3,
+        enum ggml_scale_mode  mode) {
     GGML_ASSERT(a->ne[0] <= ne0);
     GGML_ASSERT(a->ne[1] <= ne1);
     GGML_ASSERT(a->ne[2] <= ne2);
@@ -4192,6 +4183,8 @@ static struct ggml_tensor * ggml_upscale_impl(
 
     struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
 
+    ggml_set_op_params_i32(result, 0, mode);
+
     result->op     = GGML_OP_UPSCALE;
     result->src[0] = a;
 
@@ -4201,8 +4194,9 @@ static struct ggml_tensor * ggml_upscale_impl(
 struct ggml_tensor * ggml_upscale(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        int                   scale_factor) {
-    return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3]);
+        int                   scale_factor,
+        enum ggml_scale_mode  mode) {
+    return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3], mode);
 }
 
 struct ggml_tensor * ggml_upscale_ext(
@@ -4211,8 +4205,9 @@ struct ggml_tensor * ggml_upscale_ext(
         int                   ne0,
         int                   ne1,
         int                   ne2,
-        int                   ne3) {
-    return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3);
+        int                   ne3,
+        enum ggml_scale_mode  mode) {
+    return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3, mode);
 }
 
 // ggml_pad
@@ -4842,179 +4837,6 @@ struct ggml_tensor * ggml_unary_inplace(
     return ggml_unary_impl(ctx, a, op, true);
 }
 
-// ggml_map_unary
-
-static struct ggml_tensor * ggml_map_unary_impl_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        const  ggml_unary_op_f32_t fun,
-        bool   inplace) {
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op     = GGML_OP_MAP_UNARY;
-    result->src[0] = a;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_unary_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        const  ggml_unary_op_f32_t fun) {
-    return ggml_map_unary_impl_f32(ctx, a, fun, false);
-}
-
-struct ggml_tensor * ggml_map_unary_inplace_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        const  ggml_unary_op_f32_t fun) {
-    return ggml_map_unary_impl_f32(ctx, a, fun, true);
-}
-
-// ggml_map_binary
-
-static struct ggml_tensor * ggml_map_binary_impl_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b,
-        const  ggml_binary_op_f32_t fun,
-        bool   inplace) {
-    GGML_ASSERT(ggml_are_same_shape(a, b));
-
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op     = GGML_OP_MAP_BINARY;
-    result->src[0] = a;
-    result->src[1] = b;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_binary_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b,
-        const  ggml_binary_op_f32_t fun) {
-    return ggml_map_binary_impl_f32(ctx, a, b, fun, false);
-}
-
-struct ggml_tensor * ggml_map_binary_inplace_f32(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b,
-        const  ggml_binary_op_f32_t fun) {
-    return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
-}
-
-// ggml_map_custom1_f32
-
-static struct ggml_tensor * ggml_map_custom1_impl_f32(
-        struct ggml_context   * ctx,
-        struct ggml_tensor    * a,
-        const  ggml_custom1_op_f32_t fun,
-        bool   inplace) {
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op     = GGML_OP_MAP_CUSTOM1_F32;
-    result->src[0] = a;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_custom1_f32(
-        struct ggml_context   * ctx,
-        struct ggml_tensor    * a,
-        const  ggml_custom1_op_f32_t fun) {
-    return ggml_map_custom1_impl_f32(ctx, a, fun, false);
-}
-
-struct ggml_tensor * ggml_map_custom1_inplace_f32(
-        struct ggml_context   * ctx,
-        struct ggml_tensor    * a,
-        const  ggml_custom1_op_f32_t fun) {
-    return ggml_map_custom1_impl_f32(ctx, a, fun, true);
-}
-
-// ggml_map_custom2_f32
-
-static struct ggml_tensor * ggml_map_custom2_impl_f32(
-        struct ggml_context   * ctx,
-        struct ggml_tensor    * a,
-        struct ggml_tensor    * b,
-        const  ggml_custom2_op_f32_t fun,
-        bool   inplace) {
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op     = GGML_OP_MAP_CUSTOM2_F32;
-    result->src[0] = a;
-    result->src[1] = b;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_custom2_f32(
-        struct ggml_context   * ctx,
-        struct ggml_tensor    * a,
-        struct ggml_tensor    * b,
-        const  ggml_custom2_op_f32_t fun) {
-    return ggml_map_custom2_impl_f32(ctx, a, b, fun, false);
-}
-
-struct ggml_tensor * ggml_map_custom2_inplace_f32(
-        struct ggml_context   * ctx,
-        struct ggml_tensor    * a,
-        struct ggml_tensor    * b,
-        const  ggml_custom2_op_f32_t fun) {
-    return ggml_map_custom2_impl_f32(ctx, a, b, fun, true);
-}
-
-// ggml_map_custom3_f32
-
-static struct ggml_tensor * ggml_map_custom3_impl_f32(
-        struct ggml_context   * ctx,
-        struct ggml_tensor    * a,
-        struct ggml_tensor    * b,
-        struct ggml_tensor    * c,
-        const  ggml_custom3_op_f32_t fun,
-        bool   inplace) {
-    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-
-    ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
-
-    result->op     = GGML_OP_MAP_CUSTOM3_F32;
-    result->src[0] = a;
-    result->src[1] = b;
-    result->src[2] = c;
-
-    return result;
-}
-
-struct ggml_tensor * ggml_map_custom3_f32(
-        struct ggml_context   * ctx,
-        struct ggml_tensor    * a,
-        struct ggml_tensor    * b,
-        struct ggml_tensor    * c,
-        const  ggml_custom3_op_f32_t fun) {
-    return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, false);
-}
-
-struct ggml_tensor * ggml_map_custom3_inplace_f32(
-        struct ggml_context   * ctx,
-        struct ggml_tensor    * a,
-        struct ggml_tensor    * b,
-        struct ggml_tensor    * c,
-        const  ggml_custom3_op_f32_t fun) {
-    return ggml_map_custom3_impl_f32(ctx, a, b, c, fun, true);
-}
-
 // ggml_map_custom1
 
 static struct ggml_tensor * ggml_map_custom1_impl(
@@ -5033,7 +4855,7 @@ static struct ggml_tensor * ggml_map_custom1_impl(
         /*.n_tasks  =*/ n_tasks,
         /*.userdata =*/ userdata
     };
-    ggml_set_op_params(result, (const void *) &params, sizeof(params));
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op     = GGML_OP_MAP_CUSTOM1;
     result->src[0] = a;
@@ -5078,7 +4900,7 @@ static struct ggml_tensor * ggml_map_custom2_impl(
         /*.n_tasks  =*/ n_tasks,
         /*.userdata =*/ userdata
     };
-    ggml_set_op_params(result, (const void *) &params, sizeof(params));
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op     = GGML_OP_MAP_CUSTOM2;
     result->src[0] = a;
@@ -5127,7 +4949,7 @@ static struct ggml_tensor * ggml_map_custom3_impl(
         /*.n_tasks  =*/ n_tasks,
         /*.userdata =*/ userdata
     };
-    ggml_set_op_params(result, (const void *) &params, sizeof(params));
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op     = GGML_OP_MAP_CUSTOM3;
     result->src[0] = a;
@@ -5159,6 +4981,66 @@ struct ggml_tensor * ggml_map_custom3_inplace(
     return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true);
 }
 
+struct ggml_tensor * ggml_custom_4d(
+        struct ggml_context * ctx,
+        enum ggml_type        type,
+        int64_t               ne0,
+        int64_t               ne1,
+        int64_t               ne2,
+        int64_t               ne3,
+        struct ggml_tensor ** args,
+        int                   n_args,
+        ggml_custom_op_t      fun,
+        int                   n_tasks,
+        void                * userdata) {
+
+    GGML_ASSERT(n_args < GGML_MAX_SRC);
+
+    struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, ne0, ne1, ne2, ne3);
+
+    struct ggml_custom_op_params params = {
+        /*.fun      =*/ fun,
+        /*.n_tasks  =*/ n_tasks,
+        /*.userdata =*/ userdata
+    };
+    ggml_set_op_params(result, &params, sizeof(params));
+
+    result->op = GGML_OP_CUSTOM;
+    for (int i = 0; i < n_args; i++) {
+        result->src[i] = args[i];
+    }
+
+    return result;
+}
+
+struct ggml_tensor * ggml_custom_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor ** args,
+        int                   n_args,
+        ggml_custom_op_t      fun,
+        int                   n_tasks,
+        void                * userdata) {
+
+    GGML_ASSERT(n_args < GGML_MAX_SRC - 1);
+
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    struct ggml_custom_op_params params = {
+        /*.fun      =*/ fun,
+        /*.n_tasks  =*/ n_tasks,
+        /*.userdata =*/ userdata
+    };
+    ggml_set_op_params(result, &params, sizeof(params));
+
+    result->op = GGML_OP_CUSTOM;
+    result->src[0] = a;
+    for (int i = 0; i < n_args; i++) {
+        result->src[i + 1] = args[i];
+    }
+
+    return result;
+}
 // ggml_cross_entropy_loss
 
 struct ggml_tensor * ggml_cross_entropy_loss(
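For orientation, a hedged usage sketch of the new op (the callback name, shapes, and `value` are illustrative, not from the patch): the `args` tensors appear as `dst->src[0..n_args-1]`, and with `n_tasks = GGML_N_TASKS_MAX` the callback runs on every CPU thread, so it must partition the work by `ith`/`nth` itself:

```c
// Illustrative callback for ggml_custom_4d: each of the nth threads fills a
// disjoint range of rows of the (contiguous F32) dst with the value passed
// through userdata.
static void fill_rows(struct ggml_tensor * dst, int ith, int nth, void * userdata) {
    const float v = *(const float *) userdata;

    const int64_t nr  = ggml_nrows(dst);
    const int64_t dr  = (nr + nth - 1)/nth;        // rows per thread
    const int64_t ir0 = dr*ith;                    // first row of this thread
    const int64_t ir1 = ir0 + dr < nr ? ir0 + dr : nr;

    for (int64_t ir = ir0; ir < ir1; ir++) {
        float * row = (float *) ((char *) dst->data + ir*dst->nb[1]);
        for (int64_t i0 = 0; i0 < dst->ne[0]; i0++) {
            row[i0] = v;
        }
    }
}

// usage sketch: `ctx`, `a` and `b` are assumed to exist
// static float value = 1.0f;
// struct ggml_tensor * args[2] = { a, b };
// struct ggml_tensor * t = ggml_custom_4d(ctx, GGML_TYPE_F32, 16, 16, 1, 1,
//         args, 2, fill_rows, GGML_N_TASKS_MAX, &value);
```

`ggml_custom_inplace` follows the same pattern, except that `a` occupies `dst->src[0]` and the extra `args` are shifted to `dst->src[1..n_args]`.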
diff --git a/scripts/sync-ggml-am.sh b/scripts/sync-ggml-am.sh
index 914ff7c55356f..204354209f2d6 100755
--- a/scripts/sync-ggml-am.sh
+++ b/scripts/sync-ggml-am.sh
@@ -158,13 +158,13 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
     # scripts/gen-authors.sh -> scripts/gen-authors.sh
 
     cat ggml-src.patch | sed -E \
-        -e 's/(^[[:space:]]| [ab]\/)CMakeLists.txt/\1ggml\/CMakeLists.txt/g' \
-        -e 's/(^[[:space:]]| [ab]\/)src\/CMakeLists.txt/\1ggml\/src\/CMakeLists.txt/g' \
-        -e 's/(^[[:space:]]| [ab]\/)cmake\/BuildTypes.cmake/\1ggml\/cmake\/BuildTypes.cmake/g' \
-        -e 's/(^[[:space:]]| [ab]\/)cmake\/GitVars.cmake/\1ggml\/cmake\/GitVars.cmake/g' \
-        -e 's/(^[[:space:]]| [ab]\/)cmake\/common.cmake/\1ggml\/cmake\/common.cmake/g' \
-        -e 's/(^[[:space:]]| [ab]\/)cmake\/ggml-config.cmake.in/\1ggml\/cmake\/ggml-config.cmake.in/g' \
-        -e 's/(^[[:space:]]| [ab]\/)src\/ggml-cpu\/cmake\/FindSIMD.cmake/\1ggml\/src\/ggml-cpu\/cmake\/FindSIMD.cmake/g' \
+        -e 's/([[:space:]]| [ab]\/)CMakeLists.txt/\1ggml\/CMakeLists.txt/g' \
+        -e 's/([[:space:]]| [ab]\/)src\/CMakeLists.txt/\1ggml\/src\/CMakeLists.txt/g' \
+        -e 's/([[:space:]]| [ab]\/)cmake\/BuildTypes.cmake/\1ggml\/cmake\/BuildTypes.cmake/g' \
+        -e 's/([[:space:]]| [ab]\/)cmake\/GitVars.cmake/\1ggml\/cmake\/GitVars.cmake/g' \
+        -e 's/([[:space:]]| [ab]\/)cmake\/common.cmake/\1ggml\/cmake\/common.cmake/g' \
+        -e 's/([[:space:]]| [ab]\/)cmake\/ggml-config.cmake.in/\1ggml\/cmake\/ggml-config.cmake.in/g' \
+        -e 's/([[:space:]]| [ab]\/)src\/ggml-cpu\/cmake\/FindSIMD.cmake/\1ggml\/src\/ggml-cpu\/cmake\/FindSIMD.cmake/g' \
         -e 's/([[:space:]]| [ab]\/)src\/ggml(.*)\.c/\1ggml\/src\/ggml\2.c/g' \
         -e 's/([[:space:]]| [ab]\/)src\/ggml(.*)\.cpp/\1ggml\/src\/ggml\2.cpp/g' \
         -e 's/([[:space:]]| [ab]\/)src\/ggml(.*)\.h/\1ggml\/src\/ggml\2.h/g' \
@@ -180,11 +180,11 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
         -e 's/([[:space:]]| [ab]\/)src\/ggml-rpc\//\1ggml\/src\/ggml-rpc\//g' \
         -e 's/([[:space:]]| [ab]\/)src\/ggml-sycl\//\1ggml\/src\/ggml-sycl\//g' \
         -e 's/([[:space:]]| [ab]\/)src\/ggml-vulkan\//\1ggml\/src\/ggml-vulkan\//g' \
-        -e 's/^([[:space:]]| [ab]\/)include\/ggml(.*)\.h/\1ggml\/include\/ggml\2.h/g' \
-        -e 's/^([[:space:]]| [ab]\/)include\/gguf(.*)\.h/\1ggml\/include\/gguf\2.h/g' \
-        -e 's/^([[:space:]]| [ab]\/)tests\/(.*)\.cpp/\1tests\/\2.cpp/g' \
-        -e 's/^([[:space:]]| [ab]\/)LICENSE/\1LICENSE/g' \
-        -e 's/^([[:space:]]| [ab]\/)scripts\/gen-authors\.sh/\1scripts\/gen-authors.sh/g' \
+        -e 's/([[:space:]]| [ab]\/)include\/ggml(.*)\.h/\1ggml\/include\/ggml\2.h/g' \
+        -e 's/([[:space:]]| [ab]\/)include\/gguf(.*)\.h/\1ggml\/include\/gguf\2.h/g' \
+        -e 's/([[:space:]]| [ab]\/)tests\/(.*)\.cpp/\1tests\/\2.cpp/g' \
+        -e 's/([[:space:]]| [ab]\/)LICENSE/\1LICENSE/g' \
+        -e 's/([[:space:]]| [ab]\/)scripts\/gen-authors\.sh/\1scripts\/gen-authors.sh/g' \
         > ggml-src.patch.tmp
     mv ggml-src.patch.tmp ggml-src.patch
diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last
index e096778bfda55..7111936baabc8 100644
--- a/scripts/sync-ggml.last
+++ b/scripts/sync-ggml.last
@@ -1 +1 @@
-70e85f61f1fdcd1064a1e032ff564d5b5e67560c
+2abf606f098844faebee578996cae9c6d63a40e2
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index e61a126cf5b2f..3a5741c8d959d 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -271,6 +271,14 @@ static std::string var_to_str(ggml_op_pool pool) {
     }
 }
 
+static std::string var_to_str(ggml_scale_mode mode) {
+    switch (mode) {
+        case GGML_SCALE_MODE_NEAREST:  return "nearest";
+        case GGML_SCALE_MODE_BILINEAR: return "bilinear";
+        default:                       return std::to_string(mode);
+    }
+}
+
 #define VAR_TO_STR(x) (#x "=" + var_to_str(x))
 
 #define VARS_TO_STR1(a) VAR_TO_STR(a)
@@ -2948,15 +2956,16 @@ struct test_upscale : public test_case {
     const std::array<int64_t, 4> ne;
     const int32_t scale_factor;
    const bool transpose;
+    const ggml_scale_mode mode;
 
     std::string vars() override {
-        return VARS_TO_STR4(type, ne, scale_factor, transpose);
+        return VARS_TO_STR5(type, ne, scale_factor, mode, transpose);
     }
 
     test_upscale(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {512, 512, 3, 1},
-            int32_t scale_factor = 2, bool transpose = false)
-        : type(type), ne(ne), scale_factor(scale_factor), transpose(transpose) {}
+            int32_t scale_factor = 2, ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST,
+            bool transpose = false)
+        : type(type), ne(ne), scale_factor(scale_factor), transpose(transpose), mode(mode) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
@@ -2967,7 +2976,7 @@ struct test_upscale : public test_case {
             ggml_set_name(a, "a_transposed");
         }
 
-        ggml_tensor * out = ggml_upscale(ctx, a, scale_factor);
+        ggml_tensor * out = ggml_upscale(ctx, a, scale_factor, mode);
         ggml_set_name(out, "out");
 
         return out;
@@ -2979,21 +2988,23 @@ struct test_upscale_ext : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
     const std::array<int64_t, 4> ne_tgt;
+    const ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST;
 
     std::string vars() override {
-        return VARS_TO_STR3(type, ne, ne_tgt);
+        return VARS_TO_STR4(type, ne, ne_tgt, mode);
     }
 
     test_upscale_ext(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne     = {2, 5,  7, 11},
-            std::array<int64_t, 4> ne_tgt = {5, 7, 11, 13})
-        : type(type), ne(ne), ne_tgt(ne_tgt) {}
+            std::array<int64_t, 4> ne_tgt = {5, 7, 11, 13},
+            ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST)
+        : type(type), ne(ne), ne_tgt(ne_tgt), mode(mode) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_name(a, "a");
 
-        ggml_tensor * out = ggml_upscale_ext(ctx, a, ne_tgt[0], ne_tgt[1],ne_tgt[2], ne_tgt[3]);
+        ggml_tensor * out = ggml_upscale_ext(ctx, a, ne_tgt[0], ne_tgt[1],ne_tgt[2], ne_tgt[3], mode);
         ggml_set_name(out, "out");
 
         return out;
@@ -4399,12 +4410,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
     }
 
+    for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) {
+        test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode));
+        test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true));
+        test_cases.emplace_back(new test_upscale_ext(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, mode));
+    }
+
     test_cases.emplace_back(new test_sum());
     test_cases.emplace_back(new test_sum_rows());
     test_cases.emplace_back(new test_mean());
-    test_cases.emplace_back(new test_upscale());
-    test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, { 512, 512, 3, 1 }, 2, true));
-    test_cases.emplace_back(new test_upscale_ext());
     test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1}));
     test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
     test_cases.emplace_back(new test_acc());
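As a quick check (assuming the usual build layout), the new coverage can be exercised in isolation with the test binary's existing op filter, e.g. `test-backend-ops test -o UPSCALE`; the chosen mode is visible in each case's `vars()` string through the new `var_to_str(ggml_scale_mode)` overload.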