Skip to content

Commit 34bdbbd

Browse files
authored
vulkan: Remove splitting for mul_mat_id (ggml-org#15568)
row_ids only needs to hold the BN rows for the current tile.
1 parent 74f52f7 commit 34bdbbd

File tree

4 files changed

+35 -56 lines changed

4 files changed

+35 -56 lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 4 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -2090,10 +2090,11 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
20902090
const uint32_t warps = warptile[0] / warptile[10];
20912091

20922092
const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
2093-
const uint32_t mmid_row_ids = mul_mat_id ? (4096 * sizeof(uint32_t) + 4/*_ne1*/) : 0;
2093+
const uint32_t mmid_row_ids = mul_mat_id ? (warptile[2] * 2 * sizeof(uint16_t)) : 0;
20942094
const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
2095+
const uint32_t ballots_sh = mul_mat_id ? (warps * 4 * sizeof(uint32_t)) : 0;
20952096

2096-
const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
2097+
const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size + ballots_sh;
20972098
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
20982099

20992100
VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), "
@@ -6288,7 +6289,6 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
62886289

62896290
const uint64_t nei0 = ids->ne[0];
62906291
const uint64_t nei1 = ids->ne[1];
6291-
GGML_ASSERT(nei0 * nei1 <= 4096);
62926292

62936293
const uint32_t nbi1 = ids->nb[1];
62946294
const uint32_t nbi2 = ids->nb[2];
@@ -6728,37 +6728,7 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
67286728
if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
67296729
ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
67306730
} else {
6731-
// Split based on number of ids, to fit in shared memory
6732-
const uint32_t nei0 = (uint32_t)src2->ne[0];
6733-
const uint32_t nei1 = (uint32_t)src2->ne[1];
6734-
6735-
GGML_ASSERT(nei0 <= 4096);
6736-
const uint32_t split_size = std::min(nei1, 4096u / nei0);
6737-
6738-
if (split_size == nei1) {
6739-
ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
6740-
} else {
6741-
ggml_tensor src1_copy = *src1;
6742-
ggml_tensor src2_copy = *src2;
6743-
ggml_tensor dst_copy = *dst;
6744-
6745-
for (uint32_t token_start = 0; token_start < nei1; token_start += split_size) {
6746-
const uint32_t n_tokens = std::min(split_size, nei1 - token_start);
6747-
6748-
src1_copy.view_offs = src1->view_offs + token_start * src1_copy.nb[2];
6749-
src2_copy.view_offs = src2->view_offs + token_start * src2_copy.nb[1];
6750-
dst_copy.view_offs = dst->view_offs + token_start * dst_copy.nb[2];
6751-
6752-
src1_copy.ne[2] = n_tokens;
6753-
src2_copy.ne[1] = n_tokens;
6754-
dst_copy.ne[2] = n_tokens;
6755-
6756-
ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, &src1_copy, &src2_copy, &dst_copy, dryrun);
6757-
// invalidate cached prealloc_y, can't cache based on the copy of the ggml_tensor
6758-
ctx->prealloc_y_last_pipeline_used = {};
6759-
ctx->prealloc_y_last_tensor_used = nullptr;
6760-
}
6761-
}
6731+
ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
67626732
}
67636733
}
67646734

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 19 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -109,13 +109,13 @@ shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
109109
#define NUM_WARPS (BLOCK_SIZE / WARP)
110110

111111
#ifdef MUL_MAT_ID
112-
shared u16vec2 row_ids[4096];
112+
shared u16vec2 row_ids[BN];
113113
uint _ne1;
114114

115115
#ifdef MUL_MAT_ID_USE_SUBGROUPS
116116
shared uvec4 ballots_sh[NUM_WARPS];
117117

118-
void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
118+
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
119119
_ne1 = 0;
120120
uint num_elements = p.nei1 * p.nei0;
121121
uint nei0shift = findLSB(p.nei0);
@@ -165,11 +165,14 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
165165
barrier();
166166

167167
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
168-
if (in_range && id == expert_idx) {
169-
row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
168+
if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
169+
row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
170170
}
171171
_ne1 += total;
172172
iter &= 15;
173+
if (_ne1 >= (ic + 1) * BN) {
174+
break;
175+
}
173176
}
174177
barrier();
175178
}
@@ -242,16 +245,18 @@ void main() {
242245
#ifdef MUL_MAT_ID
243246
#ifdef MUL_MAT_ID_USE_SUBGROUPS
244247
if (bitCount(p.nei0) == 1) {
245-
load_row_ids(expert_idx, true);
248+
load_row_ids(expert_idx, true, ic);
246249
} else {
247-
load_row_ids(expert_idx, false);
250+
load_row_ids(expert_idx, false, ic);
248251
}
249252
#else
250253
_ne1 = 0;
251-
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
252-
for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
254+
for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
255+
for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
253256
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
254-
row_ids[_ne1] = u16vec2(ii0, ii1);
257+
if (_ne1 >= ic * BN) {
258+
row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
259+
}
255260
_ne1++;
256261
}
257262
}
@@ -797,7 +802,7 @@ void main() {
797802
[[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
798803
#if LOAD_VEC_B == 8
799804
#ifdef MUL_MAT_ID
800-
const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
805+
const u16vec2 row_idx = row_ids[loadc_b + l];
801806
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
802807
#else
803808
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
@@ -813,7 +818,7 @@ void main() {
813818
buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w);
814819
#elif LOAD_VEC_B == 4
815820
#ifdef MUL_MAT_ID
816-
const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
821+
const u16vec2 row_idx = row_ids[loadc_b + l];
817822
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
818823
#else
819824
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
@@ -832,7 +837,7 @@ void main() {
832837
#else
833838
const uint row_i = ic * BN + loadc_b + l;
834839
if (row_i < _ne1 && block + loadr_b < end_k) {
835-
const u16vec2 row_idx = row_ids[row_i];
840+
const u16vec2 row_idx = row_ids[loadc_b + l];
836841
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
837842
} else {
838843
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
@@ -903,7 +908,7 @@ void main() {
903908
const uint row_i = dc + cm_col * TN + col + store_c;
904909
if (row_i >= _ne1) break;
905910

906-
const u16vec2 row_idx = row_ids[row_i];
911+
const u16vec2 row_idx = row_ids[row_i - ic * BN];
907912

908913
if (dr + cm_row * TM + store_r < p.M) {
909914
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
@@ -953,7 +958,7 @@ void main() {
953958
const uint row_i = dc_warp + cc;
954959
if (row_i >= _ne1) break;
955960

956-
const u16vec2 row_idx = row_ids[row_i];
961+
const u16vec2 row_idx = row_ids[row_i - ic * BN];
957962
#endif // MUL_MAT_ID
958963
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
959964
#ifdef MUL_MAT_ID

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp

Lines changed: 11 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -93,7 +93,7 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
9393
#ifdef MUL_MAT_ID
9494
layout (binding = 3) readonly buffer IDS {int data_ids[];};
9595

96-
shared u16vec4 row_ids[4096];
96+
shared u16vec4 row_ids[BN];
9797

9898
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
9999
B_TYPE b[];
@@ -111,7 +111,7 @@ B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const i
111111
return B_TYPE(0.0);
112112
}
113113

114-
const u16vec4 row_idx = row_ids[row_i];
114+
const u16vec4 row_idx = row_ids[row_i & (BN - 1)];
115115
B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]];
116116

117117
return ret;
@@ -123,14 +123,14 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem
123123
uint dc = ic * BN + c;
124124

125125
if (dr < p.M && dc < _ne1) {
126-
uint row_i = dc;
126+
uint row_i = c;
127127
const u16vec4 row_idx = row_ids[row_i];
128128
data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem;
129129
}
130130
return elem;
131131
}
132132

133-
void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
133+
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
134134
_ne1 = 0;
135135
uint num_elements = p.nei1 * p.nei0;
136136
uint nei0shift = findLSB(p.nei0);
@@ -180,11 +180,14 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
180180
barrier();
181181

182182
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
183-
if (in_range && id == expert_idx) {
184-
row_ids[_ne1 + idx] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0);
183+
if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
184+
row_ids[_ne1 + idx - ic * BN] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0);
185185
}
186186
_ne1 += total;
187187
iter &= 15;
188+
if (_ne1 >= (ic + 1) * BN) {
189+
break;
190+
}
188191
}
189192
barrier();
190193
}
@@ -218,9 +221,9 @@ void main() {
218221

219222
#ifdef MUL_MAT_ID
220223
if (bitCount(p.nei0) == 1) {
221-
load_row_ids(expert_idx, true);
224+
load_row_ids(expert_idx, true, ic);
222225
} else {
223-
load_row_ids(expert_idx, false);
226+
load_row_ids(expert_idx, false, ic);
224227
}
225228

226229
// Workgroup has no work

tests/test-backend-ops.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6017,6 +6017,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
60176017
// test large experts*tokens
60186018
for (bool b : {false, true}) {
60196019
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, b, 32, 1024, 16));
6020+
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 2, 2, b, 32, 8192, 64));
60206021
}
60216022

60226023
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));

0 commit comments

Comments
 (0)