@@ -425,7 +425,7 @@ struct vk_device_struct {
425425 vk_pipeline pipeline_div_norepeat[2][2][2];
426426
427427 vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
428- vk_pipeline pipeline_upscale_f32 ;
428+ vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32 ;
429429 vk_pipeline pipeline_scale_f32;
430430 vk_pipeline pipeline_sqr_f32;
431431 vk_pipeline pipeline_sin_f32;
@@ -895,6 +895,7 @@ struct vk_op_conv2d_dw_push_constants {
895895
896896struct vk_op_upscale_push_constants {
897897 uint32_t ne; uint32_t a_offset; uint32_t d_offset;
898+ uint32_t ne00; uint32_t ne01;
898899 uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
899900 uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
900901 float sf0; float sf1; float sf2; float sf3;
@@ -2856,7 +2857,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
28562857 ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
28572858 ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
28582859
2859- ggml_vk_create_pipeline(device, device->pipeline_upscale_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {}, 1);
2860+ ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1);
2861+ ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1);
2862+ ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_ac_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS}, 1);
28602863
28612864 ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
28622865
@@ -6536,8 +6539,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
65366539 }
65376540 return nullptr;
65386541 case GGML_OP_UPSCALE:
6539- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && dst->op_params[0] == GGML_SCALE_MODE_NEAREST) {
6540- return ctx->device->pipeline_upscale_f32;
6542+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6543+ int mode = ggml_get_op_params_i32(dst, 0);
6544+ switch (mode) {
6545+ case GGML_SCALE_MODE_NEAREST:
6546+ return ctx->device->pipeline_upscale_nearest_f32;
6547+ case GGML_SCALE_MODE_BILINEAR:
6548+ return ctx->device->pipeline_upscale_bilinear_f32;
6549+ case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS:
6550+ return ctx->device->pipeline_upscale_bilinear_ac_f32;
6551+ }
65416552 }
65426553 return nullptr;
65436554 case GGML_OP_SCALE:
@@ -7586,14 +7597,21 @@ static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, co
75867597
75877598static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
75887599 const uint32_t src0_type_size = ggml_type_size(src0->type);
7600+ const uint32_t mode = (uint32_t)ggml_get_op_params_i32(dst, 0);
75897601
7590- const float sf0 = (float)dst->ne[0] / src0->ne[0];
7591- const float sf1 = (float)dst->ne[1] / src0->ne[1];
7592- const float sf2 = (float)dst->ne[2] / src0->ne[2];
7593- const float sf3 = (float)dst->ne[3] / src0->ne[3];
7602+ float sf0 = (float)dst->ne[0] / src0->ne[0];
7603+ float sf1 = (float)dst->ne[1] / src0->ne[1];
7604+ float sf2 = (float)dst->ne[2] / src0->ne[2];
7605+ float sf3 = (float)dst->ne[3] / src0->ne[3];
7606+
7607+ if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) {
7608+ sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1);
7609+ sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
7610+ }
75947611
75957612 ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
75967613 (uint32_t)ggml_nelements(dst), 0, 0,
7614+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1],
75977615 (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
75987616 (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3],
75997617 sf0, sf1, sf2, sf3,
@@ -10578,13 +10596,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1057810596 case GGML_OP_CLAMP:
1057910597 return op->src[0]->type == GGML_TYPE_F32;
1058010598 case GGML_OP_UPSCALE:
10581- return op->op_params[0] == GGML_SCALE_MODE_NEAREST;
1058210599 case GGML_OP_ACC:
1058310600 case GGML_OP_CONCAT:
1058410601 case GGML_OP_SCALE:
1058510602 case GGML_OP_PAD:
10603+ case GGML_OP_ROLL:
1058610604 case GGML_OP_DIAG_MASK_INF:
10587- return true;
1058810605 case GGML_OP_SOFT_MAX:
1058910606 case GGML_OP_SOFT_MAX_BACK:
1059010607 case GGML_OP_ARGSORT:
0 commit comments