@@ -361,6 +361,9 @@ enum vk_conv_shapes {
     CONV_SHAPE_COUNT,
 };
 
+static constexpr uint32_t num_argsort_pipelines = 11;
+static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
+
 struct vk_device_struct {
     std::recursive_mutex mutex;
 
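For context on the two new constants: there is one argsort pipeline per power-of-two workgroup width, 2^0 through 2^(num_argsort_pipelines-1), so the widest row that can be sorted in a single workgroup is 1 << 10 = 1024 columns. A minimal standalone sketch of that relationship (nothing here beyond the constants above):

```cpp
#include <cstdint>

// Mirror of the constants introduced above: 11 pipeline variants cover
// workgroup widths 1, 2, 4, ..., 1024, so the widest sortable row is 1024.
static constexpr uint32_t num_argsort_pipelines = 11;
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines - 1);

static_assert(max_argsort_cols == 1024, "pipeline i handles up to (1u << i) columns");
```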
@@ -477,6 +480,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32;
     vk_pipeline pipeline_scale_f32;
     vk_pipeline pipeline_sqr_f32;
+    vk_pipeline pipeline_sqrt_f32;
     vk_pipeline pipeline_sin_f32;
     vk_pipeline pipeline_cos_f32;
     vk_pipeline pipeline_clamp_f32;
@@ -521,7 +525,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
     vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
     vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
-    vk_pipeline pipeline_argsort_f32;
+    vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
     vk_pipeline pipeline_sum_rows_f32;
     vk_pipeline pipeline_argmax_f32;
     vk_pipeline pipeline_count_equal_i32;
@@ -886,7 +890,6 @@ struct vk_op_soft_max_push_constants {
 
 struct vk_op_argsort_push_constants {
     uint32_t ncols;
-    uint32_t ncols_pad;
     int32_t order;
 };
 
@@ -3045,6 +3048,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
@@ -3115,7 +3119,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
         ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     }
 
-    ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
+    for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
+        ggml_vk_create_pipeline(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1u<<i, 1, 1}, {1u<<i, i}, 1, true);
+    }
 
     ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 
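The loop replaces the single fixed 1024-thread argsort pipeline with one variant per power-of-two column count; each variant's workgroup width is 1u<<i, and that width plus the index i are passed to the shader as specialization constants (the shader-side use is not part of this diff). A hedged sketch that only enumerates what the loop creates, with the pipeline call itself stubbed out:

```cpp
#include <cstdint>
#include <iostream>
#include <string>

// Sketch only: list the argsort pipeline variants the loop above creates.
// ggml_vk_create_pipeline itself is not reproduced; this just models the
// name, workgroup size and specialization-constant arguments it receives.
int main() {
    constexpr uint32_t num_argsort_pipelines = 11;
    for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
        const uint32_t wg = 1u << i;  // dispatched as workgroup {wg, 1, 1}
        std::cout << "argsort_f32_" + std::to_string(i)
                  << ": workgroup width " << wg
                  << ", spec constants {" << wg << ", " << i << "}\n";
    }
    return 0;
}
```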
@@ -7007,6 +7013,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_sqr_f32;
         }
         return nullptr;
+    case GGML_OP_SQRT:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_sqrt_f32;
+        }
+        return nullptr;
     case GGML_OP_SIN:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_sin_f32;
@@ -7190,7 +7201,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         }
     case GGML_OP_ARGSORT:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
-            return ctx->device->pipeline_argsort_f32;
+            uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
+            return ctx->device->pipeline_argsort_f32[idx];
         }
         return nullptr;
     case GGML_OP_SUM:
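Selection rounds the column count up to the next power of two: idx = ceil(log2(ne[0])), i.e. the smallest i with (1u << i) >= ncols, so 1000 columns pick index 10 (the 1024-wide variant). A small standalone check of that mapping, using the same formula as the patch:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Same formula as ggml_vk_op_get_pipeline above: smallest i with (1u << i) >= ncols.
static uint32_t argsort_pipeline_idx(uint32_t ncols) {
    return (uint32_t)ceilf(log2f((float)ncols));
}

int main() {
    assert(argsort_pipeline_idx(1)    == 0);   // 1 column  -> 1-wide variant
    assert(argsort_pipeline_idx(3)    == 2);   // rounds up to 4
    assert(argsort_pipeline_idx(1000) == 10);  // rounds up to 1024
    assert(argsort_pipeline_idx(1024) == 10);  // exact power of two
    return 0;
}
```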
@@ -7315,6 +7327,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_CONCAT:
     case GGML_OP_UPSCALE:
     case GGML_OP_SQR:
+    case GGML_OP_SQRT:
     case GGML_OP_SIN:
     case GGML_OP_COS:
     case GGML_OP_CLAMP:
@@ -7620,6 +7633,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_MUL:
     case GGML_OP_SCALE:
     case GGML_OP_SQR:
+    case GGML_OP_SQRT:
     case GGML_OP_SIN:
     case GGML_OP_COS:
     case GGML_OP_CLAMP:
@@ -8267,6 +8281,10 @@ static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const
     ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun);
 }
 
+static void ggml_vk_sqrt(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQRT, vk_op_unary_push_constants_init(src0, dst), dryrun);
+}
+
 static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun);
 }
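ggml_vk_sqrt follows the same unary-op pattern as its sqr/sin/cos neighbours: it forwards to ggml_vk_op_f32 with GGML_OP_SQRT and the standard unary push constants. At the graph level the path is exercised by any GGML_OP_SQRT node; a minimal hedged sketch using the public ggml API (context and backend setup omitted):

```cpp
#include "ggml.h"

// Minimal sketch: build an element-wise square-root node over a 1-D f32 tensor.
// When the graph containing it runs on the Vulkan backend, the node is
// dispatched through the ggml_vk_sqrt() wrapper added above.
static struct ggml_tensor * build_sqrt_node(struct ggml_context * ctx, int64_t n) {
    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n);
    return ggml_sqrt(ctx, x);  // GGML_OP_SQRT, f32 in / f32 out
}
```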
@@ -8515,16 +8533,8 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
 
     uint32_t ncols = src0->ne[0];
 
-    uint32_t ncols_pad = 1;
-    while (ncols_pad < ncols) {
-        ncols_pad *= 2;
-    }
-
-    GGML_ASSERT(ncols_pad <= 1024);
-
     ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
         ncols,
-        ncols_pad,
         op_params[0],
     }, dryrun);
 }
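The host-side padding loop and the GGML_ASSERT(ncols_pad <= 1024) disappear because the pad is now implicit in pipeline selection: the power of two the old loop computed equals the workgroup width (1u << idx) of the variant chosen in ggml_vk_op_get_pipeline, so it no longer needs to travel as a push constant. A quick standalone check of that equivalence (helper names are illustrative, not from the patch):

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Old behaviour: the host rounded ncols up to the next power of two.
static uint32_t old_ncols_pad(uint32_t ncols) {
    uint32_t p = 1;
    while (p < ncols) p *= 2;
    return p;
}

// New behaviour: the pad is the workgroup width of the selected pipeline.
static uint32_t new_pipeline_width(uint32_t ncols) {
    return 1u << (uint32_t)ceilf(log2f((float)ncols));
}

int main() {
    for (uint32_t n = 1; n <= 1024; ++n) {
        assert(old_ncols_pad(n) == new_pipeline_width(n));
    }
    return 0;
}
```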
@@ -9730,6 +9740,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_UPSCALE:
     case GGML_OP_SCALE:
     case GGML_OP_SQR:
+    case GGML_OP_SQRT:
     case GGML_OP_SIN:
     case GGML_OP_COS:
     case GGML_OP_CLAMP:
@@ -9799,6 +9810,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_UPSCALE:
     case GGML_OP_SCALE:
     case GGML_OP_SQR:
+    case GGML_OP_SQRT:
     case GGML_OP_SIN:
     case GGML_OP_COS:
     case GGML_OP_CLAMP:
@@ -9900,6 +9912,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_SQR:
         ggml_vk_sqr(ctx, compute_ctx, src0, node, dryrun);
 
+        break;
+    case GGML_OP_SQRT:
+        ggml_vk_sqrt(ctx, compute_ctx, src0, node, dryrun);
+
         break;
     case GGML_OP_SIN:
         ggml_vk_sin(ctx, compute_ctx, src0, node, dryrun);
@@ -10151,6 +10167,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
     case GGML_OP_UPSCALE:
     case GGML_OP_SCALE:
     case GGML_OP_SQR:
+    case GGML_OP_SQRT:
     case GGML_OP_SIN:
     case GGML_OP_COS:
     case GGML_OP_CLAMP:
@@ -11390,13 +11407,16 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_SILU_BACK:
         case GGML_OP_RMS_NORM_BACK:
         case GGML_OP_SQR:
+        case GGML_OP_SQRT:
         case GGML_OP_SIN:
         case GGML_OP_COS:
         case GGML_OP_CLAMP:
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_OPT_STEP_ADAMW:
         case GGML_OP_OPT_STEP_SGD:
             return op->src[0]->type == GGML_TYPE_F32;
+        case GGML_OP_ARGSORT:
+            return op->ne[0] <= max_argsort_cols;
         case GGML_OP_UPSCALE:
         case GGML_OP_ACC:
         case GGML_OP_CONCAT:
@@ -11406,7 +11426,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_SOFT_MAX_BACK:
-        case GGML_OP_ARGSORT:
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGMAX:
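With the blanket GGML_OP_ARGSORT case removed here, argsort support is now reported per tensor: rows up to max_argsort_cols (1024) run on Vulkan, wider rows are declined so the scheduler can fall back to another backend instead of tripping the old GGML_ASSERT at dispatch time. A hedged stand-in for that predicate (function name is illustrative, not from the patch):

```cpp
#include <cstdint>

static constexpr uint32_t max_argsort_cols = 1u << 10;  // mirrors the new constant

// Illustrative equivalent of the new GGML_OP_ARGSORT branch in supports_op:
// claim the op only when the sorted row fits one of the workgroup variants.
static bool vk_supports_argsort(int64_t ne0) {
    return ne0 <= (int64_t)max_argsort_cols;
}
```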
@@ -11833,6 +11852,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]);
     } else if (tensor->op == GGML_OP_SQR) {
         tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
+    } else if (tensor->op == GGML_OP_SQRT) {
+        tensor_clone = ggml_sqrt(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_SIN) {
         tensor_clone = ggml_sin(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_COS) {