Skip to content

Commit 330c3d2

Browse files
authored
vulkan: optimize mul_mat_id loading row ids into shared memory (ggml-org#15427)
- Spread the work across the whole workgroup. Using more threads seems to far outweigh the synchronization overhead. - Specialize the code for when the division is by a power of two.
1 parent e92734d commit 330c3d2

File tree

3 files changed

+133
-81
lines changed

3 files changed

+133
-81
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2168,9 +2168,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
21682168
s_mmq_wg_denoms_k = { 32, 64, 1 };
21692169

21702170
// spec constants and tile sizes for quant matmul_id
2171-
l_warptile_mmqid = { 256, 128, 128, 16, 0 };
2172-
m_warptile_mmqid = { 256, 128, 64, 16, 0 };
2173-
s_warptile_mmqid = { 256, 128, 64, 16, 0 };
2171+
l_warptile_mmqid = { 256, 128, 128, 16, 0, device->subgroup_size };
2172+
m_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
2173+
s_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
21742174
l_mmqid_wg_denoms = { 128, 128, 1 };
21752175
m_mmqid_wg_denoms = { 128, 64, 1 };
21762176
s_mmqid_wg_denoms = { 128, 64, 1 };

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 65 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -103,16 +103,74 @@ layout (constant_id = 10) const uint WARP = 32;
103103
shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE];
104104
shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
105105

106+
#define NUM_WARPS (BLOCK_SIZE / WARP)
107+
106108
#ifdef MUL_MAT_ID
107109
shared u16vec2 row_ids[4096];
108110
uint _ne1;
109111
#ifdef COOPMAT
110-
shared uint _ne1_sh;
112+
shared uvec4 ballots_sh[NUM_WARPS];
113+
void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
114+
_ne1 = 0;
115+
uint num_elements = p.nei1 * p.nei0;
116+
uint nei0shift = findLSB(p.nei0);
117+
118+
uint ids[16];
119+
uint iter = 0;
120+
121+
for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
122+
// prefetch up to 16 elements
123+
if (iter == 0) {
124+
[[unroll]] for (uint k = 0; k < 16; ++k) {
125+
uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
126+
bool in_range = i < num_elements;
127+
uint ii1;
128+
if (nei0_is_pow2) {
129+
ii1 = i >> nei0shift;
130+
} else {
131+
ii1 = i / p.nei0;
132+
}
133+
uint ii0 = i - ii1 * p.nei0;
134+
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
135+
}
136+
}
137+
uint i = j + gl_LocalInvocationIndex;
138+
bool in_range = i < num_elements;
139+
uint ii1;
140+
if (nei0_is_pow2) {
141+
ii1 = i >> nei0shift;
142+
} else {
143+
ii1 = i / p.nei0;
144+
}
145+
uint ii0 = i - ii1 * p.nei0;
146+
uint id = ids[iter++];
147+
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
148+
149+
ballots_sh[gl_SubgroupID] = ballot;
150+
barrier();
151+
152+
uint subgroup_base = 0;
153+
uint total = 0;
154+
for (uint k = 0; k < gl_NumSubgroups; ++k) {
155+
if (k == gl_SubgroupID) {
156+
subgroup_base = total;
157+
}
158+
total += subgroupBallotBitCount(ballots_sh[k]);
159+
}
160+
barrier();
161+
162+
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
163+
if (in_range && id == expert_idx) {
164+
row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
165+
}
166+
_ne1 += total;
167+
iter &= 15;
168+
}
169+
barrier();
170+
}
111171
#endif
112172
#endif // MUL_MAT_ID
113173

114-
#define NUM_WARPS (BLOCK_SIZE / WARP)
115-
116174
#ifdef COOPMAT
117175
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
118176
#endif
@@ -178,44 +236,11 @@ void main() {
178236

179237
#ifdef MUL_MAT_ID
180238
#ifdef COOPMAT
181-
// Spread the search across all elements in the first subgroup
182-
if (gl_SubgroupID == 0) {
183-
_ne1 = 0;
184-
uint num_elements = p.nei1 * p.nei0;
185-
186-
uint ids[16];
187-
uint iter = 0;
188-
189-
for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
190-
// prefetch up to 16 elements
191-
if (iter == 0) {
192-
[[unroll]] for (uint k = 0; k < 16; ++k) {
193-
uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
194-
bool in_range = i < num_elements;
195-
uint ii1 = i / p.nei0;
196-
uint ii0 = i % p.nei0;
197-
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
198-
}
199-
}
200-
uint i = j + gl_SubgroupInvocationID;
201-
bool in_range = i < num_elements;
202-
uint ii1 = i / p.nei0;
203-
uint ii0 = i % p.nei0;
204-
uint id = ids[iter++];
205-
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
206-
uint idx = subgroupBallotExclusiveBitCount(ballot);
207-
if (in_range && id == expert_idx) {
208-
row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
209-
}
210-
_ne1 += subgroupBallotBitCount(ballot);
211-
iter &= 15;
212-
}
213-
_ne1_sh = _ne1;
239+
if (bitCount(p.nei0) == 1) {
240+
load_row_ids(expert_idx, true);
241+
} else {
242+
load_row_ids(expert_idx, false);
214243
}
215-
216-
barrier();
217-
218-
_ne1 = _ne1_sh;
219244
#else
220245
_ne1 = 0;
221246
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp

Lines changed: 65 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#endif
2020

2121
#include "types.comp"
22+
#include "utils.comp"
2223

2324
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
2425

@@ -99,7 +100,8 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
99100
};
100101

101102
uint _ne1;
102-
shared uint _ne1_sh;
103+
layout (constant_id = 5) const uint subgroup_size = 32;
104+
shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size];
103105

104106
B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
105107
{
@@ -128,6 +130,64 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem
128130
return elem;
129131
}
130132

133+
void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
134+
_ne1 = 0;
135+
uint num_elements = p.nei1 * p.nei0;
136+
uint nei0shift = findLSB(p.nei0);
137+
138+
uint ids[16];
139+
uint iter = 0;
140+
141+
for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
142+
// prefetch up to 16 elements
143+
if (iter == 0) {
144+
[[unroll]] for (uint k = 0; k < 16; ++k) {
145+
uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
146+
bool in_range = i < num_elements;
147+
uint ii1;
148+
if (nei0_is_pow2) {
149+
ii1 = i >> nei0shift;
150+
} else {
151+
ii1 = i / p.nei0;
152+
}
153+
uint ii0 = i - ii1 * p.nei0;
154+
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
155+
}
156+
}
157+
uint i = j + gl_LocalInvocationIndex;
158+
bool in_range = i < num_elements;
159+
uint ii1;
160+
if (nei0_is_pow2) {
161+
ii1 = i >> nei0shift;
162+
} else {
163+
ii1 = i / p.nei0;
164+
}
165+
uint ii0 = i - ii1 * p.nei0;
166+
uint id = ids[iter++];
167+
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
168+
169+
ballots_sh[gl_SubgroupID] = ballot;
170+
barrier();
171+
172+
uint subgroup_base = 0;
173+
uint total = 0;
174+
for (uint k = 0; k < gl_NumSubgroups; ++k) {
175+
if (k == gl_SubgroupID) {
176+
subgroup_base = total;
177+
}
178+
total += subgroupBallotBitCount(ballots_sh[k]);
179+
}
180+
barrier();
181+
182+
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
183+
if (in_range && id == expert_idx) {
184+
row_ids[_ne1 + idx] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0);
185+
}
186+
_ne1 += total;
187+
iter &= 15;
188+
}
189+
barrier();
190+
}
131191
#endif
132192

133193
void main() {
@@ -157,45 +217,12 @@ void main() {
157217
const uint ic = gl_WorkGroupID.y;
158218

159219
#ifdef MUL_MAT_ID
160-
// Spread the search across all elements in the first subgroup
161-
if (gl_SubgroupID == 0) {
162-
_ne1 = 0;
163-
uint num_elements = p.nei1 * p.nei0;
164-
165-
uint ids[16];
166-
uint iter = 0;
167-
168-
for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
169-
// prefetch up to 16 elements
170-
if (iter == 0) {
171-
[[unroll]] for (uint k = 0; k < 16; ++k) {
172-
uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
173-
bool in_range = i < num_elements;
174-
uint ii1 = i / p.nei0;
175-
uint ii0 = i % p.nei0;
176-
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
177-
}
178-
}
179-
uint i = j + gl_SubgroupInvocationID;
180-
bool in_range = i < num_elements;
181-
uint ii1 = i / p.nei0;
182-
uint ii0 = i % p.nei0;
183-
uint id = ids[iter++];
184-
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
185-
uint idx = subgroupBallotExclusiveBitCount(ballot);
186-
if (in_range && id == expert_idx) {
187-
row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0);
188-
}
189-
_ne1 += subgroupBallotBitCount(ballot);
190-
iter &= 15;
191-
}
192-
_ne1_sh = _ne1;
220+
if (bitCount(p.nei0) == 1) {
221+
load_row_ids(expert_idx, true);
222+
} else {
223+
load_row_ids(expert_idx, false);
193224
}
194225

195-
barrier();
196-
197-
_ne1 = _ne1_sh;
198-
199226
// Workgroup has no work
200227
if (ic * BN >= _ne1) return;
201228
#endif

0 commit comments

Comments
 (0)