@@ -425,7 +425,7 @@ struct vk_device_struct {
425425 vk_pipeline pipeline_div_norepeat[2][2][2];
426426
427427 vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
428- vk_pipeline pipeline_upscale_f32 ;
428+ vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32 ;
429429 vk_pipeline pipeline_scale_f32;
430430 vk_pipeline pipeline_sqr_f32;
431431 vk_pipeline pipeline_sin_f32;
@@ -894,6 +894,7 @@ struct vk_op_conv2d_dw_push_constants {
894894
895895struct vk_op_upscale_push_constants {
896896 uint32_t ne; uint32_t a_offset; uint32_t d_offset;
897+ uint32_t ne00; uint32_t ne01;
897898 uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
898899 uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
899900 float sf0; float sf1; float sf2; float sf3;
@@ -2822,7 +2823,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
28222823 ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
28232824 ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
28242825
2825- ggml_vk_create_pipeline(device, device->pipeline_upscale_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {}, 1);
2826+ ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1);
2827+ ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1);
2828+ ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_ac_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS}, 1);
28262829
28272830 ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
28282831
@@ -6502,8 +6505,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
65026505 }
65036506 return nullptr;
65046507 case GGML_OP_UPSCALE:
6505- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && dst->op_params[0] == GGML_SCALE_MODE_NEAREST) {
6506- return ctx->device->pipeline_upscale_f32;
6508+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6509+ int mode = ggml_get_op_params_i32(dst, 0);
6510+ switch (mode) {
6511+ case GGML_SCALE_MODE_NEAREST:
6512+ return ctx->device->pipeline_upscale_nearest_f32;
6513+ case GGML_SCALE_MODE_BILINEAR:
6514+ return ctx->device->pipeline_upscale_bilinear_f32;
6515+ case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS:
6516+ return ctx->device->pipeline_upscale_bilinear_ac_f32;
6517+ }
65076518 }
65086519 return nullptr;
65096520 case GGML_OP_SCALE:
@@ -7524,14 +7535,21 @@ static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, co
75247535
75257536static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
75267537 const uint32_t src0_type_size = ggml_type_size(src0->type);
7538+ const uint32_t mode = (uint32_t)ggml_get_op_params_i32(dst, 0);
75277539
7528- const float sf0 = (float)dst->ne[0] / src0->ne[0];
7529- const float sf1 = (float)dst->ne[1] / src0->ne[1];
7530- const float sf2 = (float)dst->ne[2] / src0->ne[2];
7531- const float sf3 = (float)dst->ne[3] / src0->ne[3];
7540+ float sf0 = (float)dst->ne[0] / src0->ne[0];
7541+ float sf1 = (float)dst->ne[1] / src0->ne[1];
7542+ float sf2 = (float)dst->ne[2] / src0->ne[2];
7543+ float sf3 = (float)dst->ne[3] / src0->ne[3];
7544+
7545+ if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) {
7546+ sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1);
7547+ sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
7548+ }
75327549
75337550 ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
75347551 (uint32_t)ggml_nelements(dst), 0, 0,
7552+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1],
75357553 (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
75367554 (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3],
75377555 sf0, sf1, sf2, sf3,
@@ -10483,13 +10501,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1048310501 case GGML_OP_CLAMP:
1048410502 return op->src[0]->type == GGML_TYPE_F32;
1048510503 case GGML_OP_UPSCALE:
10486- return op->op_params[0] == GGML_SCALE_MODE_NEAREST;
1048710504 case GGML_OP_ACC:
1048810505 case GGML_OP_CONCAT:
1048910506 case GGML_OP_SCALE:
1049010507 case GGML_OP_PAD:
10508+ case GGML_OP_ROLL:
1049110509 case GGML_OP_DIAG_MASK_INF:
10492- return true;
1049310510 case GGML_OP_SOFT_MAX:
1049410511 case GGML_OP_SOFT_MAX_BACK:
1049510512 case GGML_OP_ARGSORT:
0 commit comments