@@ -410,7 +410,7 @@ struct vk_device_struct {
410
410
vk_pipeline pipeline_div_norepeat[2][2][2];
411
411
412
412
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
413
- vk_pipeline pipeline_upscale_f32 ;
413
+ vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32 ;
414
414
vk_pipeline pipeline_scale_f32;
415
415
vk_pipeline pipeline_sqr_f32;
416
416
vk_pipeline pipeline_sin_f32;
@@ -880,6 +880,7 @@ struct vk_op_conv2d_dw_push_constants {
880
880
881
881
struct vk_op_upscale_push_constants {
882
882
uint32_t ne; uint32_t a_offset; uint32_t d_offset;
883
+ uint32_t ne00; uint32_t ne01;
883
884
uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
884
885
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
885
886
float sf0; float sf1; float sf2; float sf3;
@@ -2773,7 +2774,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
2773
2774
ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
2774
2775
ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
2775
2776
2776
- ggml_vk_create_pipeline(device, device->pipeline_upscale_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {}, 1);
2777
+ ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1);
2778
+ ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1);
2779
+ ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_ac_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS}, 1);
2777
2780
2778
2781
ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2779
2782
@@ -6425,8 +6428,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6425
6428
}
6426
6429
return nullptr;
6427
6430
case GGML_OP_UPSCALE:
6428
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && dst->op_params[0] == GGML_SCALE_MODE_NEAREST) {
6429
- return ctx->device->pipeline_upscale_f32;
6431
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6432
+ int mode = ggml_get_op_params_i32(dst, 0);
6433
+ switch (mode) {
6434
+ case GGML_SCALE_MODE_NEAREST:
6435
+ return ctx->device->pipeline_upscale_nearest_f32;
6436
+ case GGML_SCALE_MODE_BILINEAR:
6437
+ return ctx->device->pipeline_upscale_bilinear_f32;
6438
+ case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS:
6439
+ return ctx->device->pipeline_upscale_bilinear_ac_f32;
6440
+ }
6430
6441
}
6431
6442
return nullptr;
6432
6443
case GGML_OP_SCALE:
@@ -7441,14 +7452,21 @@ static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, co
7441
7452
7442
7453
static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7443
7454
const uint32_t src0_type_size = ggml_type_size(src0->type);
7455
+ const uint32_t mode = (uint32_t)ggml_get_op_params_i32(dst, 0);
7444
7456
7445
- const float sf0 = (float)dst->ne[0] / src0->ne[0];
7446
- const float sf1 = (float)dst->ne[1] / src0->ne[1];
7447
- const float sf2 = (float)dst->ne[2] / src0->ne[2];
7448
- const float sf3 = (float)dst->ne[3] / src0->ne[3];
7457
+ float sf0 = (float)dst->ne[0] / src0->ne[0];
7458
+ float sf1 = (float)dst->ne[1] / src0->ne[1];
7459
+ float sf2 = (float)dst->ne[2] / src0->ne[2];
7460
+ float sf3 = (float)dst->ne[3] / src0->ne[3];
7461
+
7462
+ if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) {
7463
+ sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1);
7464
+ sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
7465
+ }
7449
7466
7450
7467
ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
7451
7468
(uint32_t)ggml_nelements(dst), 0, 0,
7469
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1],
7452
7470
(uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
7453
7471
(uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3],
7454
7472
sf0, sf1, sf2, sf3,
@@ -10346,7 +10364,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
10346
10364
case GGML_OP_CLAMP:
10347
10365
return op->src[0]->type == GGML_TYPE_F32;
10348
10366
case GGML_OP_UPSCALE:
10349
- return op->op_params[0] == GGML_SCALE_MODE_NEAREST;
10350
10367
case GGML_OP_ACC:
10351
10368
case GGML_OP_CONCAT:
10352
10369
case GGML_OP_SCALE:
0 commit comments