@@ -1056,6 +1056,7 @@ struct vk_op_rope_push_constants {
10561056 uint32_t s1;
10571057 uint32_t s2;
10581058 int32_t sections[4];
1059+ uint32_t is_imrope;
10591060 uint32_t is_back;
10601061 uint32_t set_rows_stride;
10611062};
@@ -1082,6 +1083,7 @@ struct vk_op_soft_max_push_constants {
10821083
10831084struct vk_op_argsort_push_constants {
10841085 uint32_t ncols;
1086+ uint32_t nrows;
10851087 int32_t order;
10861088};
10871089
@@ -8708,6 +8710,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
87088710 break;
87098711 case GGML_OP_ARGSORT:
87108712 elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
8713+ elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
87118714 break;
87128715 case GGML_OP_IM2COL:
87138716 {
@@ -9925,6 +9928,8 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
99259928 memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
99269929 }
99279930
9931+ const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
9932+
99289933 float corr_dims[2];
99299934 ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
99309935
@@ -9946,17 +9951,19 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
99469951 (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
99479952 freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
99489953 src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
9949- { sections[0], sections[1], sections[2], sections[3] }, backprop, set_rows_stride,
9954+ { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,
99509955 }, dryrun);
99519956}
99529957
99539958static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
99549959 int32_t * op_params = (int32_t *)dst->op_params;
99559960
99569961 uint32_t ncols = src0->ne[0];
9962+ uint32_t nrows = ggml_nrows(src0);
99579963
99589964 ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
99599965 ncols,
9966+ nrows,
99609967 op_params[0],
99619968 }, dryrun);
99629969}
0 commit comments