@@ -464,6 +464,7 @@ struct vk_device_struct {
464
464
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32;
465
465
vk_pipeline pipeline_scale_f32;
466
466
vk_pipeline pipeline_sqr_f32;
467
+ vk_pipeline pipeline_sqrt_f32;
467
468
vk_pipeline pipeline_sin_f32;
468
469
vk_pipeline pipeline_cos_f32;
469
470
vk_pipeline pipeline_clamp_f32;
@@ -3031,6 +3032,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
3031
3032
ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3032
3033
3033
3034
ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3035
+ ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3034
3036
ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3035
3037
ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3036
3038
@@ -6981,6 +6983,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
6981
6983
return ctx->device->pipeline_sqr_f32;
6982
6984
}
6983
6985
return nullptr;
6986
+ case GGML_OP_SQRT:
6987
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6988
+ return ctx->device->pipeline_sqrt_f32;
6989
+ }
6990
+ return nullptr;
6984
6991
case GGML_OP_SIN:
6985
6992
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6986
6993
return ctx->device->pipeline_sin_f32;
@@ -7290,6 +7297,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
7290
7297
case GGML_OP_CONCAT:
7291
7298
case GGML_OP_UPSCALE:
7292
7299
case GGML_OP_SQR:
7300
+ case GGML_OP_SQRT:
7293
7301
case GGML_OP_SIN:
7294
7302
case GGML_OP_COS:
7295
7303
case GGML_OP_CLAMP:
@@ -7595,6 +7603,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
7595
7603
case GGML_OP_MUL:
7596
7604
case GGML_OP_SCALE:
7597
7605
case GGML_OP_SQR:
7606
+ case GGML_OP_SQRT:
7598
7607
case GGML_OP_SIN:
7599
7608
case GGML_OP_COS:
7600
7609
case GGML_OP_CLAMP:
@@ -8242,6 +8251,10 @@ static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const
8242
8251
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun);
8243
8252
}
8244
8253
8254
// Enqueue an element-wise square root (GGML_OP_SQRT) of src0 into dst.
// Thin wrapper over the generic unary-op path: builds the standard unary
// push constants from the src/dst tensors and dispatches through
// ggml_vk_op_f32 with no extra source operands (src1/src2 are nullptr).
// dryrun: NOTE(review) — presumably only records pipeline/allocation needs
// without submitting GPU work; confirm against ggml_vk_op_f32's handling.
static void ggml_vk_sqrt(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQRT, vk_op_unary_push_constants_init(src0, dst), dryrun);
}
8257
+
8245
8258
// Enqueue an element-wise sine (GGML_OP_SIN) of src0 into dst.
// Same shape as the other unary wrappers in this file (sqr/sqrt/cos/...):
// standard unary push constants, no extra source operands, dispatched via
// the shared ggml_vk_op_f32 helper.
// dryrun: NOTE(review) — presumably skips actual submission and only does
// pipeline bookkeeping; confirm against ggml_vk_op_f32.
static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun);
}
@@ -9697,6 +9710,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
9697
9710
case GGML_OP_UPSCALE:
9698
9711
case GGML_OP_SCALE:
9699
9712
case GGML_OP_SQR:
9713
+ case GGML_OP_SQRT:
9700
9714
case GGML_OP_SIN:
9701
9715
case GGML_OP_COS:
9702
9716
case GGML_OP_CLAMP:
@@ -9766,6 +9780,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
9766
9780
case GGML_OP_UPSCALE:
9767
9781
case GGML_OP_SCALE:
9768
9782
case GGML_OP_SQR:
9783
+ case GGML_OP_SQRT:
9769
9784
case GGML_OP_SIN:
9770
9785
case GGML_OP_COS:
9771
9786
case GGML_OP_CLAMP:
@@ -9867,6 +9882,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
9867
9882
case GGML_OP_SQR:
9868
9883
ggml_vk_sqr(ctx, compute_ctx, src0, node, dryrun);
9869
9884
9885
+ break;
9886
+ case GGML_OP_SQRT:
9887
+ ggml_vk_sqrt(ctx, compute_ctx, src0, node, dryrun);
9888
+
9870
9889
break;
9871
9890
case GGML_OP_SIN:
9872
9891
ggml_vk_sin(ctx, compute_ctx, src0, node, dryrun);
@@ -10118,6 +10137,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
10118
10137
case GGML_OP_UPSCALE:
10119
10138
case GGML_OP_SCALE:
10120
10139
case GGML_OP_SQR:
10140
+ case GGML_OP_SQRT:
10121
10141
case GGML_OP_SIN:
10122
10142
case GGML_OP_COS:
10123
10143
case GGML_OP_CLAMP:
@@ -11357,6 +11377,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
11357
11377
case GGML_OP_SILU_BACK:
11358
11378
case GGML_OP_RMS_NORM_BACK:
11359
11379
case GGML_OP_SQR:
11380
+ case GGML_OP_SQRT:
11360
11381
case GGML_OP_SIN:
11361
11382
case GGML_OP_COS:
11362
11383
case GGML_OP_CLAMP:
@@ -11801,6 +11822,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
11801
11822
tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]);
11802
11823
} else if (tensor->op == GGML_OP_SQR) {
11803
11824
tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
11825
+ } else if (tensor->op == GGML_OP_SQRT) {
11826
+ tensor_clone = ggml_sqrt(ggml_ctx, src_clone[0]);
11804
11827
} else if (tensor->op == GGML_OP_SIN) {
11805
11828
tensor_clone = ggml_sin(ggml_ctx, src_clone[0]);
11806
11829
} else if (tensor->op == GGML_OP_COS) {
0 commit comments