Skip to content

Commit 4abef75

Browse files
authored
vulkan: Implement SOLVE_TRI (#17486)
* vulkan: Implement SOLVE_TRI * load B matrix through shared memory * use FLOAT_TYPE
1 parent c386114 commit 4abef75

File tree

3 files changed

+167
-0
lines changed

3 files changed

+167
-0
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,18 @@ struct vk_conv2d_pipeline_state {
399399
}
400400
};
401401

402+
// Specialization key for the SOLVE_TRI compute pipeline: one pipeline is
// compiled per (N, K) problem size, where N is the triangular-matrix
// dimension and K is the number of right-hand-side columns.
struct vk_solve_tri_pipeline_state {
    vk_solve_tri_pipeline_state(uint32_t N, uint32_t K)
        : N(N), K(K) {}

    uint32_t N, K;

    // Strict weak ordering (lexicographic on N, then K) so this type can
    // serve as a std::map key.
    bool operator<(const vk_solve_tri_pipeline_state &b) const {
        if (N != b.N) {
            return N < b.N;
        }
        return K < b.K;
    }
};
413+
402414
enum shader_reduction_mode {
403415
SHADER_REDUCTION_MODE_SHMEM,
404416
SHADER_REDUCTION_MODE_HYBRID,
@@ -711,6 +723,7 @@ struct vk_device_struct {
711723
vk_pipeline pipeline_cumsum_f32;
712724
vk_pipeline pipeline_argmax_f32;
713725
vk_pipeline pipeline_count_equal_i32;
726+
std::map<vk_solve_tri_pipeline_state, vk_pipeline> pipeline_solve_tri_f32;
714727
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
715728
vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16;
716729
vk_pipeline pipeline_timestep_embedding_f32;
@@ -4002,6 +4015,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
40024015

40034016
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
40044017

4018+
for (auto &s : device->pipeline_solve_tri_f32) {
4019+
const vk_solve_tri_pipeline_state &state = s.first;
4020+
ggml_vk_create_pipeline(
4021+
device, s.second, "solve_tri_f32",
4022+
solve_tri_f32_len, solve_tri_f32_data, "main", 3,
4023+
sizeof(vk_op_binary_push_constants), {1, 1, 1}, { 0, state.N, state.K }, 1, true);
4024+
}
4025+
40054026
#define IM2COL(bda) \
40064027
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
40074028
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
@@ -8496,6 +8517,26 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
84968517
return ctx->device->pipeline_cumsum_f32;
84978518
}
84988519
return nullptr;
8520+
case GGML_OP_SOLVE_TRI:
8521+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8522+
8523+
vk_solve_tri_pipeline_state solve_tri_pipeline_state(src0->ne[0], src1->ne[0]);
8524+
8525+
vk_pipeline pipeline = nullptr;
8526+
8527+
{
8528+
std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
8529+
auto it = ctx->device->pipeline_solve_tri_f32.find(solve_tri_pipeline_state);
8530+
if (it != ctx->device->pipeline_solve_tri_f32.end()) {
8531+
pipeline = it->second;
8532+
} else {
8533+
ctx->device->pipeline_solve_tri_f32[solve_tri_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
8534+
}
8535+
}
8536+
8537+
return pipeline;
8538+
}
8539+
return nullptr;
84998540
case GGML_OP_ARGMAX:
85008541
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
85018542
return ctx->device->pipeline_argmax_f32;
@@ -8832,6 +8873,18 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
88328873
elements = { nr, 1, 1 };
88338874
}
88348875
} break;
8876+
case GGML_OP_SOLVE_TRI:
8877+
{
8878+
uint32_t nr = (uint32_t)(ne02 * ne03);
8879+
if (nr > 262144) {
8880+
elements = { 512, 512, CEIL_DIV(nr, 262144) };
8881+
} else if (nr > 512) {
8882+
elements = { 512, CEIL_DIV(nr, 512), 1 };
8883+
} else {
8884+
elements = { nr, 1, 1 };
8885+
}
8886+
}
8887+
break;
88358888
case GGML_OP_RMS_NORM:
88368889
if (ctx->do_add_rms_partials) {
88378890
// Run one element per thread, 128 threads per workgroup
@@ -10260,6 +10313,21 @@ static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subct
1026010313
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
1026110314
}
1026210315

10316+
// Record a SOLVE_TRI dispatch. src0 holds the triangular matrices, src1 the
// right-hand sides; the solution is written to dst. All shapes and byte
// strides are packed into the generic binary-op push-constant struct, with
// strides converted from bytes to element counts.
static void ggml_vk_solve_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    const uint32_t ts0 = ggml_type_size(src0->type);
    const uint32_t ts1 = ggml_type_size(src1->type);
    const uint32_t tsd = ggml_type_size(dst->type);

    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOLVE_TRI, {
        (uint32_t)ggml_nelements(src0),
        // src0: dims ne[0..3] followed by element strides nb[0..3]
        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3],
        (uint32_t)(src0->nb[0] / ts0), (uint32_t)(src0->nb[1] / ts0), (uint32_t)(src0->nb[2] / ts0), (uint32_t)(src0->nb[3] / ts0),
        // src1: dims followed by element strides
        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2], (uint32_t)src1->ne[3],
        (uint32_t)(src1->nb[0] / ts1), (uint32_t)(src1->nb[1] / ts1), (uint32_t)(src1->nb[2] / ts1), (uint32_t)(src1->nb[3] / ts1),
        // dst: dims followed by element strides
        (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2], (uint32_t)dst->ne[3],
        (uint32_t)(dst->nb[0] / tsd), (uint32_t)(dst->nb[1] / tsd), (uint32_t)(dst->nb[2] / tsd), (uint32_t)(dst->nb[3] / tsd),
        // remaining generic binary-op fields are unused by this op
        0,
        0.0f, 0.0f, 0,
    });
}
10330+
1026310331
static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
1026410332
const int32_t s0 = dst->op_params[0];
1026510333
const int32_t s1 = dst->op_params[1];
@@ -11871,6 +11939,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1187111939
case GGML_OP_COUNT_EQUAL:
1187211940
ggml_vk_count_equal(ctx, compute_ctx, src0, src1, node);
1187311941

11942+
break;
11943+
case GGML_OP_SOLVE_TRI:
11944+
ggml_vk_solve_tri(ctx, compute_ctx, src0, src1, node);
11945+
1187411946
break;
1187511947
case GGML_OP_IM2COL:
1187611948
ggml_vk_im2col(ctx, compute_ctx, src0, src1, node);
@@ -13916,6 +13988,25 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1391613988
}
1391713989
return false;
1391813990
}
13991+
case GGML_OP_SOLVE_TRI:
13992+
{
13993+
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
13994+
const vk_device& device = ggml_vk_get_device(ctx->device);
13995+
13996+
if (op->type != GGML_TYPE_F32 || op->src[0]->type != GGML_TYPE_F32) {
13997+
return false;
13998+
}
13999+
const uint32_t N = op->src[0]->ne[0];
14000+
const uint32_t K = op->src[1]->ne[0];
14001+
// K dimension limited to workgroup size
14002+
if (K > 128) {
14003+
return false;
14004+
}
14005+
if (N * N * sizeof(float) + N * K * sizeof(float) > device->properties.limits.maxComputeSharedMemorySize) {
14006+
return false;
14007+
}
14008+
return true;
14009+
}
1391914010
case GGML_OP_ARGMAX:
1392014011
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1392114012
case GGML_OP_COUNT_EQUAL:
@@ -14588,6 +14679,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1458814679
tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]);
1458914680
} else if (tensor->op == GGML_OP_COUNT_EQUAL) {
1459014681
tensor_clone = ggml_count_equal(ggml_ctx, src_clone[0], src_clone[1]);
14682+
} else if (tensor->op == GGML_OP_SOLVE_TRI) {
14683+
tensor_clone = ggml_solve_tri(ggml_ctx, src_clone[0], src_clone[1], true, true, false);
1459114684
} else if (tensor->op == GGML_OP_IM2COL) {
1459214685
const int32_t s0 = tensor->op_params[0];
1459314686
const int32_t s1 = tensor->op_params[1];
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
#version 450

#include "types.glsl"
#include "generic_binary_head.glsl"

// Specialization constants chosen at pipeline-creation time:
// N = dimension of the square triangular matrix A, K = number of RHS columns.
layout (constant_id = 1) const uint N = 64;
layout (constant_id = 2) const uint K = 32;

// One workgroup solves one (i2, i3) batch slice; 128 threads per group.
layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;

// Per-batch base offsets (in elements) into the A, B and output buffers.
uint a_base, b_base, x_base;

// Fetch A[r][c] using the strides from the push constants.
FLOAT_TYPE get_a(uint r, uint c) {
    return FLOAT_TYPE(data_a[a_base + r * p.nb01 + c * p.nb00]);
}

// Fetch B[r][c] using the strides from the push constants.
FLOAT_TYPE get_b(uint r, uint c) {
    return FLOAT_TYPE(data_b[b_base + r * p.nb11 + c * p.nb10]);
}

// Store one solved element X[r][c] to the destination buffer.
void store_x(uint r, uint c, FLOAT_TYPE v) {
    data_d[x_base + r * p.nb21 + c * p.nb20] = D_TYPE(v);
}

// Shared-memory staging for the whole A matrix and B panel; the host has
// already checked that N*N + N*K floats fit in maxComputeSharedMemorySize.
shared FLOAT_TYPE shA[N * N];
shared FLOAT_TYPE shB[N * K];

void main() {
    // Workgroup IDs are folded into a flat batch index; the dispatch side
    // splits batches into a (512, 512, ceil(nr/262144)) grid, so undo that here.
    const uint batch = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
    const uint tid = gl_LocalInvocationID.x;

    // Trailing groups past the real batch count have nothing to do.
    if (batch >= p.ne02 * p.ne03) {
        return;
    }

    const uint i3 = batch / p.ne22;
    const uint i2 = batch % p.ne22;
    a_base = get_aoffset() + i2 * p.nb02 + i3 * p.nb03;
    b_base = get_boffset() + i2 * p.nb12 + i3 * p.nb13;
    x_base = get_doffset() + i2 * p.nb22 + i3 * p.nb23;

    // Load the A matrix into shA.
    // The modulo test is compile-time constant, so the bounds check is
    // eliminated entirely when N*N divides the workgroup size.
    [[unroll]] for (uint i = 0; i < N * N; i += gl_WorkGroupSize.x) {
        uint idx = i + tid;
        if (((N * N) % gl_WorkGroupSize.x == 0) || idx < N * N) {
            shA[idx] = get_a(idx / N, idx % N);
        }
    }
    // Load the B matrix into shB (same bounds-check trick as above).
    [[unroll]] for (uint i = 0; i < N * K; i += gl_WorkGroupSize.x) {
        uint idx = i + tid;
        if (((N * K) % gl_WorkGroupSize.x == 0) || idx < N * K) {
            shB[idx] = get_b(idx / K, idx % K);
        }
    }
    // All loads must be visible before any thread starts the solve.
    barrier();

    // Per-thread copy of the solved column so the inner reduction reads
    // registers rather than shared/global memory.
    FLOAT_TYPE X[N];
    // Each thread solves one column of X by forward substitution over the
    // rows of A (processing rows r = 0..N-1 in order; earlier X entries feed
    // later rows). The host rejects K > 128, so K never exceeds the
    // workgroup size.
    if (tid < K) {
        [[unroll]] for (int r = 0; r < N; ++r) {
            FLOAT_TYPE b = shB[r * K + tid];
            // Compute x[r,c] = (b[r,c] - sum(a[r,c]*x[c])) / a[r,r]
            [[unroll]] for (int c = 0; c < r; ++c) {
                b -= shA[r * N + c] * X[c];
            }
            FLOAT_TYPE x = b / shA[r * N + r];
            X[r] = x;
            store_x(r, tid, x);
        }
    }
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -944,6 +944,8 @@ void process_shaders() {
944944
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
945945
string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
946946

947+
string_to_spv("solve_tri_f32", "solve_tri.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
948+
947949
for (auto transpose : {false, true}) {
948950
for (auto unroll : {false, true}) {
949951
for (auto a_f16 : {false, true}) {

0 commit comments

Comments
 (0)