@@ -399,6 +399,18 @@ struct vk_conv2d_pipeline_state {
399399 }
400400};
401401
// Cache key for the solve_tri (triangular solve) compute pipeline: one
// pipeline is specialized per (N, K) problem size.
//   N — dimension of the square triangular matrix (src0 is N x N)
//   K — number of right-hand-side columns (src1 is N x K)
struct vk_solve_tri_pipeline_state {
    vk_solve_tri_pipeline_state(uint32_t N, uint32_t K)
        : N(N), K(K) {}

    uint32_t N, K;

    // Strict weak ordering (lexicographic over (N, K)) so the state can be
    // used as a std::map key.
    bool operator<(const vk_solve_tri_pipeline_state &b) const {
        return std::tie(N, K) < std::tie(b.N, b.K);
    }
};
413+
402414enum shader_reduction_mode {
403415 SHADER_REDUCTION_MODE_SHMEM,
404416 SHADER_REDUCTION_MODE_HYBRID,
@@ -711,6 +723,7 @@ struct vk_device_struct {
711723 vk_pipeline pipeline_cumsum_f32;
712724 vk_pipeline pipeline_argmax_f32;
713725 vk_pipeline pipeline_count_equal_i32;
726+ std::map<vk_solve_tri_pipeline_state, vk_pipeline> pipeline_solve_tri_f32;
714727 vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
715728 vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16;
716729 vk_pipeline pipeline_timestep_embedding_f32;
@@ -4002,6 +4015,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
40024015
40034016 ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
40044017
4018+ for (auto &s : device->pipeline_solve_tri_f32) {
4019+ const vk_solve_tri_pipeline_state &state = s.first;
4020+ ggml_vk_create_pipeline(
4021+ device, s.second, "solve_tri_f32",
4022+ solve_tri_f32_len, solve_tri_f32_data, "main", 3,
4023+ sizeof(vk_op_binary_push_constants), {1, 1, 1}, { 0, state.N, state.K }, 1, true);
4024+ }
4025+
40054026#define IM2COL(bda) \
40064027 ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
40074028 ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
@@ -8496,6 +8517,26 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
84968517 return ctx->device->pipeline_cumsum_f32;
84978518 }
84988519 return nullptr;
8520+ case GGML_OP_SOLVE_TRI:
8521+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8522+
8523+ vk_solve_tri_pipeline_state solve_tri_pipeline_state(src0->ne[0], src1->ne[0]);
8524+
8525+ vk_pipeline pipeline = nullptr;
8526+
8527+ {
8528+ std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
8529+ auto it = ctx->device->pipeline_solve_tri_f32.find(solve_tri_pipeline_state);
8530+ if (it != ctx->device->pipeline_solve_tri_f32.end()) {
8531+ pipeline = it->second;
8532+ } else {
8533+ ctx->device->pipeline_solve_tri_f32[solve_tri_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
8534+ }
8535+ }
8536+
8537+ return pipeline;
8538+ }
8539+ return nullptr;
84998540 case GGML_OP_ARGMAX:
85008541 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
85018542 return ctx->device->pipeline_argmax_f32;
@@ -8832,6 +8873,18 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
88328873 elements = { nr, 1, 1 };
88338874 }
88348875 } break;
8876+ case GGML_OP_SOLVE_TRI:
8877+ {
8878+ uint32_t nr = (uint32_t)(ne02 * ne03);
8879+ if (nr > 262144) {
8880+ elements = { 512, 512, CEIL_DIV(nr, 262144) };
8881+ } else if (nr > 512) {
8882+ elements = { 512, CEIL_DIV(nr, 512), 1 };
8883+ } else {
8884+ elements = { nr, 1, 1 };
8885+ }
8886+ }
8887+ break;
88358888 case GGML_OP_RMS_NORM:
88368889 if (ctx->do_add_rms_partials) {
88378890 // Run one element per thread, 128 threads per workgroup
@@ -10260,6 +10313,21 @@ static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subct
1026010313 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
1026110314}
1026210315
10316+ static void ggml_vk_solve_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
10317+ const uint32_t src0_type_size = ggml_type_size(src0->type);
10318+ const uint32_t src1_type_size = ggml_type_size(src1->type);
10319+ const uint32_t dst_type_size = ggml_type_size(dst->type);
10320+
10321+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOLVE_TRI, {
10322+ (uint32_t)ggml_nelements(src0),
10323+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
10324+ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
10325+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
10326+ 0,
10327+ 0.0f, 0.0f, 0,
10328+ });
10329+ }
10330+
1026310331static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1026410332 const int32_t s0 = dst->op_params[0];
1026510333 const int32_t s1 = dst->op_params[1];
@@ -11871,6 +11939,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1187111939 case GGML_OP_COUNT_EQUAL:
1187211940 ggml_vk_count_equal(ctx, compute_ctx, src0, src1, node);
1187311941
11942+ break;
11943+ case GGML_OP_SOLVE_TRI:
11944+ ggml_vk_solve_tri(ctx, compute_ctx, src0, src1, node);
11945+
1187411946 break;
1187511947 case GGML_OP_IM2COL:
1187611948 ggml_vk_im2col(ctx, compute_ctx, src0, src1, node);
@@ -13916,6 +13988,25 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1391613988 }
1391713989 return false;
1391813990 }
13991+ case GGML_OP_SOLVE_TRI:
13992+ {
13993+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
13994+ const vk_device& device = ggml_vk_get_device(ctx->device);
13995+
13996+ if (op->type != GGML_TYPE_F32 || op->src[0]->type != GGML_TYPE_F32) {
13997+ return false;
13998+ }
13999+ const uint32_t N = op->src[0]->ne[0];
14000+ const uint32_t K = op->src[1]->ne[0];
14001+ // K dimension limited to workgroup size
14002+ if (K > 128) {
14003+ return false;
14004+ }
14005+ if (N * N * sizeof(float) + N * K * sizeof(float) > device->properties.limits.maxComputeSharedMemorySize) {
14006+ return false;
14007+ }
14008+ return true;
14009+ }
1391914010 case GGML_OP_ARGMAX:
1392014011 return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1392114012 case GGML_OP_COUNT_EQUAL:
@@ -14588,6 +14679,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1458814679 tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]);
1458914680 } else if (tensor->op == GGML_OP_COUNT_EQUAL) {
1459014681 tensor_clone = ggml_count_equal(ggml_ctx, src_clone[0], src_clone[1]);
14682+ } else if (tensor->op == GGML_OP_SOLVE_TRI) {
14683+ tensor_clone = ggml_solve_tri(ggml_ctx, src_clone[0], src_clone[1], true, true, false);
1459114684 } else if (tensor->op == GGML_OP_IM2COL) {
1459214685 const int32_t s0 = tensor->op_params[0];
1459314686 const int32_t s1 = tensor->op_params[1];
0 commit comments