Commit c6c5e85

vulkan: support solve_tri with larger N/K values (ggml-org#17781)
Split N into chunks to fit into shared memory. If K > 128, use a larger workgroup with enough invocations. Add perf tests matching qwen3next.
1 parent: 8e5f498

File tree

3 files changed, +62 -30 lines changed


ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 11 additions & 3 deletions

@@ -4033,10 +4033,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     for (auto &s : device->pipeline_solve_tri_f32) {
         const vk_solve_tri_pipeline_state &state = s.first;
+
+        // Max number of rows to load at a time, limited by shared memory
+        const uint32_t batch_N = device->properties.limits.maxComputeSharedMemorySize / ((state.N + state.K) * sizeof(float));
+        // Need at least K invocations, and prefer a minimum of 128 to spread out loading shared memory
+        const uint32_t block_size = std::max(128u, 1u << (uint32_t)ceilf(log2f(float(state.K))));
+
         ggml_vk_create_pipeline(
             device, s.second, "solve_tri_f32",
             solve_tri_f32_len, solve_tri_f32_data, "main", 3,
-            sizeof(vk_op_binary_push_constants), {1, 1, 1}, { 0, state.N, state.K }, 1, true);
+            sizeof(vk_op_binary_push_constants), {1, 1, 1}, { 0, state.N, state.K, batch_N, block_size }, 1, true);
     }
 
 #define IM2COL(bda) \
@@ -14025,10 +14031,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
     const uint32_t N = op->src[0]->ne[0];
     const uint32_t K = op->src[1]->ne[0];
     // K dimension limited to workgroup size
-    if (K > 128) {
+    if (K > 1u << device->max_workgroup_size_log2) {
        return false;
     }
-    if (N * N * sizeof(float) + N * K * sizeof(float) > device->properties.limits.maxComputeSharedMemorySize) {
+    const uint32_t batch_N = device->properties.limits.maxComputeSharedMemorySize / ((N + K) * sizeof(float));
+
+    if (batch_N == 0) {
         return false;
     }
     return true;
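
To make the new limits concrete, here is a minimal standalone sketch of the same sizing arithmetic, assuming a device that reports a 32 KiB maxComputeSharedMemorySize (the limit and the N = K = 128 problem size are illustrative, not taken from the commit):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
    // Assumed device limit; real devices report it via
    // VkPhysicalDeviceLimits::maxComputeSharedMemorySize.
    const uint32_t max_shared = 32 * 1024;
    const uint32_t N = 128, K = 128;

    // Rows that fit in shared memory at once: each row needs N floats of A
    // plus K floats of B.
    const uint32_t batch_N = max_shared / ((N + K) * sizeof(float));

    // At least K invocations, rounded up to a power of two, with a floor of 128.
    const uint32_t block_size = std::max(128u, 1u << (uint32_t)ceilf(log2f(float(K))));

    printf("batch_N = %u, block_size = %u\n", batch_N, block_size);
    // -> batch_N = 32, block_size = 128: the N = 128 rows are solved in 4 chunks.
    return 0;
}
```

Under the old supports_op check this shape was rejected outright (N*N + N*K floats is 128 KiB of shared memory); under the new check it is accepted as long as at least one row fits per chunk (batch_N > 0).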

ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp

Lines changed: 36 additions & 27 deletions

@@ -5,8 +5,9 @@
 
 layout (constant_id = 1) const uint N = 64;
 layout (constant_id = 2) const uint K = 32;
+layout (constant_id = 3) const uint BATCH_N = 32;
 
-layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x_id = 4, local_size_y = 1, local_size_z = 1) in;
 
 uint a_base, b_base, x_base;
 
@@ -22,8 +23,8 @@ void store_x(uint r, uint c, FLOAT_TYPE v) {
     data_d[x_base + r * p.nb21 + c * p.nb20] = D_TYPE(v);
 }
 
-shared FLOAT_TYPE shA[N * N];
-shared FLOAT_TYPE shB[N * K];
+shared FLOAT_TYPE shA[BATCH_N * N];
+shared FLOAT_TYPE shB[BATCH_N * K];
 
 void main() {
     const uint batch = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
@@ -39,34 +40,42 @@ void main() {
     b_base = get_boffset() + i2 * p.nb12 + i3 * p.nb13;
     x_base = get_doffset() + i2 * p.nb22 + i3 * p.nb23;
 
-    // Load the A matrix into shA
-    [[unroll]] for (uint i = 0; i < N * N; i += gl_WorkGroupSize.x) {
-        uint idx = i + tid;
-        if (((N * N) % gl_WorkGroupSize.x == 0) || idx < N * N) {
-            shA[idx] = get_a(idx / N, idx % N);
+    FLOAT_TYPE X[N];
+
+    // Loop over batches of rows
+    [[unroll]] for (uint row_base = 0; row_base < N; row_base += BATCH_N) {
+        const uint cur_N = min(BATCH_N, N - row_base);
+
+        // Load the A matrix batch into shA
+        [[unroll]] for (uint i = 0; i < cur_N * N; i += gl_WorkGroupSize.x) {
+            uint idx = i + tid;
+            if (((cur_N * N) % gl_WorkGroupSize.x == 0) || idx < cur_N * N) {
+                shA[idx] = get_a(row_base + idx / N, idx % N);
+            }
         }
-    }
-    // Load the B matrix into shB
-    [[unroll]] for (uint i = 0; i < N * K; i += gl_WorkGroupSize.x) {
-        uint idx = i + tid;
-        if (((N * K) % gl_WorkGroupSize.x == 0) || idx < N * K) {
-            shB[idx] = get_b(idx / K, idx % K);
+        // Load the B matrix batch into shB
+        [[unroll]] for (uint i = 0; i < cur_N * K; i += gl_WorkGroupSize.x) {
+            uint idx = i + tid;
+            if (((cur_N * K) % gl_WorkGroupSize.x == 0) || idx < cur_N * K) {
+                shB[idx] = get_b(row_base + idx / K, idx % K);
+            }
         }
-    }
-    barrier();
+        barrier();
 
-    FLOAT_TYPE X[N];
-    // Each thread solves one column
-    if (tid < K) {
-        [[unroll]] for (int r = 0; r < N; ++r) {
-            FLOAT_TYPE b = shB[r * K + tid];
-            // Compute x[r,c] = (b[r,c] - sum(a[r,c]*x[c])) / a[r,r]
-            [[unroll]] for (int c = 0; c < r; ++c) {
-                b -= shA[r * N + c] * X[c];
+        // Each thread solves one column
+        if (tid < K) {
+            [[unroll]] for (uint row_offset = 0; row_offset < cur_N; ++row_offset) {
+                uint r = row_base + row_offset;
+                FLOAT_TYPE b = shB[row_offset * K + tid];
+                // Compute x[r,c] = (b[r,c] - sum(a[r,c]*x[c])) / a[r,r]
+                [[unroll]] for (int c = 0; c < r; ++c) {
+                    b -= shA[row_offset * N + c] * X[c];
+                }
+                FLOAT_TYPE x = b / shA[row_offset * N + r];
+                X[r] = x;
+                store_x(r, tid, x);
             }
-            FLOAT_TYPE x = b / shA[r * N + r];
-            X[r] = x;
-            store_x(r, tid, x);
         }
+        barrier();
     }
 }
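
The per-column math is unchanged by the chunking: each invocation still performs plain forward substitution down its column, and the row batching only changes which slice of A and B is resident in shared memory. A minimal scalar sketch of that recurrence (the helper name solve_tri_ref and the row-major layout are assumptions for illustration, not ggml API):

```cpp
#include <cstdio>
#include <vector>

// Solve L * X = B by forward substitution, where a is an N x N row-major
// lower-triangular matrix and b/x are N x K row-major matrices.
static void solve_tri_ref(const std::vector<float> &a,
                          const std::vector<float> &b,
                          std::vector<float> &x, int N, int K) {
    for (int c = 0; c < K; ++c) {            // one shader invocation per column
        for (int r = 0; r < N; ++r) {        // rows in order, top to bottom
            float acc = b[r * K + c];
            for (int j = 0; j < r; ++j) {    // subtract already-solved entries
                acc -= a[r * N + j] * x[j * K + c];
            }
            x[r * K + c] = acc / a[r * N + r];
        }
    }
}

int main() {
    // 2x2 example: L = [[2,0],[1,4]], B = [[2],[9]] -> X = [[1],[2]]
    std::vector<float> a = {2, 0, 1, 4}, b = {2, 9}, x(2);
    solve_tri_ref(a, b, x, 2, 1);
    printf("x = [%g, %g]\n", x[0], x[1]);
    return 0;
}
```

In the shader, each invocation owns one column (c = tid) and keeps the solved prefix of its column in the X[N] register array, so later rows never re-read x from memory; the barrier() at the end of each chunk keeps the next chunk's shared-memory loads from racing ahead of threads still solving the current one.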

tests/test-backend-ops.cpp

Lines changed: 15 additions & 0 deletions

@@ -6204,6 +6204,15 @@ struct test_solve_tri : public test_case {
 
     std::string vars() override { return VARS_TO_STR3(type, ne_lhs, ne_rhs); }
 
+    uint64_t op_flops(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        int64_t n = ne_lhs[0];
+        int64_t k = ne_rhs[0];
+        int64_t batch = ne_lhs[2] * ne_lhs[3];
+        // n * (n + 1) / 2 non-zero elements of lhs, 2 flops each, for each col of rhs
+        return n * (n + 1) * k * batch;
+    }
+
     test_solve_tri(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne_lhs = { 10, 10, 4, 3 },
             std::array<int64_t, 4> ne_rhs = { 3, 10, 4, 3 }
@@ -7816,6 +7825,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 42, 42, 5, 2 }, { 10, 42, 5, 2 }));
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 2, 2 }, { 10, 64, 2, 2 }));
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 100, 100, 4, 4 }, { 41, 100, 4, 4 }));
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 4 }, { 31, 128, 4, 4 }));
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 300, 64, 4, 4 }));
 
     for (bool v : {false, true}) {
         test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, v));
@@ -8016,6 +8027,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
 
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 }));
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 }));
+    // qwen3next with CHUNK_SIZE 64
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 8, 32 }, { 64, 64, 8, 32 }));
+    // qwen3next with CHUNK_SIZE 128
+    test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 32 }, { 128, 128, 4, 32 }));
 
     test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F32, { 256, 256, 4, 4 }));
     test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG, GGML_TYPE_F32, { 1024, 1024, 8, 4 }));
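
As a cross-check on the new op_flops accounting, the larger qwen3next perf case works out as follows (a worked example, not part of the test file):

```cpp
#include <cstdint>
#include <cstdio>

// n * (n + 1) / 2 non-zero lhs elements, one multiply plus one add each,
// per rhs column, across all batches.
int main() {
    const int64_t n = 128, k = 128, batch = 4 * 32; // ne_lhs = { 128, 128, 4, 32 }
    const int64_t flops = n * (n + 1) * k * batch;
    printf("%lld flops (~%.2f GFLOP per op)\n", (long long)flops, flops / 1e9);
    // -> 270532608 flops, roughly 0.27 GFLOP per solve_tri op
    return 0;
}
```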
