Skip to content

Commit 2b54086

Browse files
committed
vulkan: optimize mat_mul_id row_ids search to batch loads, and port to coopmat1 path
1 parent b54ddba commit 2b54086

File tree

3 files changed

+66
-5
lines changed

3 files changed

+66
-5
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1821,7 +1821,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
18211821
const uint32_t warps = warptile[0] / warptile[10];
18221822

18231823
const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
1824-
const uint32_t mmid_row_ids = mul_mat_id ? 4096 * sizeof(uint32_t) : 0;
1824+
const uint32_t mmid_row_ids = mul_mat_id ? (4096 * sizeof(uint32_t) + 4/*_ne1*/) : 0;
18251825
const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
18261826

18271827
const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#extension GL_KHR_cooperative_matrix : enable
1919
#extension GL_KHR_memory_scope_semantics : enable
2020
#extension GL_KHR_shader_subgroup_basic : enable
21+
#extension GL_KHR_shader_subgroup_ballot : enable
2122
#endif
2223

2324
#ifdef MUL_MAT_ID
@@ -104,6 +105,10 @@ shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
104105

105106
#ifdef MUL_MAT_ID
106107
shared u16vec2 row_ids[4096];
108+
uint _ne1;
109+
#ifdef COOPMAT
110+
shared uint _ne1_sh;
111+
#endif
107112
#endif // MUL_MAT_ID
108113

109114
#define NUM_WARPS (BLOCK_SIZE / WARP)
@@ -172,7 +177,47 @@ void main() {
172177
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
173178

174179
#ifdef MUL_MAT_ID
175-
uint _ne1 = 0;
180+
#ifdef COOPMAT
181+
// Spread the search across all elements in the first subgroup
182+
if (gl_SubgroupID == 0) {
183+
_ne1 = 0;
184+
uint num_elements = p.nei1 * p.nei0;
185+
186+
uint ids[16];
187+
uint iter = 0;
188+
189+
for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
190+
// prefetch up to 16 elements
191+
if (iter == 0) {
192+
[[unroll]] for (uint k = 0; k < 16; ++k) {
193+
uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
194+
bool in_range = i < num_elements;
195+
uint ii1 = i / p.nei0;
196+
uint ii0 = i % p.nei0;
197+
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
198+
}
199+
}
200+
uint i = j + gl_SubgroupInvocationID;
201+
bool in_range = i < num_elements;
202+
uint ii1 = i / p.nei0;
203+
uint ii0 = i % p.nei0;
204+
uint id = ids[iter++];
205+
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
206+
uint idx = subgroupBallotExclusiveBitCount(ballot);
207+
if (in_range && id == expert_idx) {
208+
row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
209+
}
210+
_ne1 += subgroupBallotBitCount(ballot);
211+
iter &= 15;
212+
}
213+
_ne1_sh = _ne1;
214+
}
215+
216+
barrier();
217+
218+
_ne1 = _ne1_sh;
219+
#else
220+
_ne1 = 0;
176221
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
177222
for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
178223
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
@@ -183,6 +228,7 @@ void main() {
183228
}
184229

185230
barrier();
231+
#endif
186232

187233
// Workgroup has no work
188234
if (ic * BN >= _ne1) return;

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,17 +162,32 @@ void main() {
162162
_ne1 = 0;
163163
uint num_elements = p.nei1 * p.nei0;
164164

165-
for (uint i = gl_SubgroupInvocationID; subgroupAny(i < num_elements); i += gl_SubgroupSize) {
165+
uint ids[16];
166+
uint iter = 0;
167+
168+
for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
169+
// prefetch up to 16 elements
170+
if (iter == 0) {
171+
[[unroll]] for (uint k = 0; k < 16; ++k) {
172+
uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
173+
bool in_range = i < num_elements;
174+
uint ii1 = i / p.nei0;
175+
uint ii0 = i % p.nei0;
176+
ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
177+
}
178+
}
179+
uint i = j + gl_SubgroupInvocationID;
166180
bool in_range = i < num_elements;
167-
uint ii0 = i % p.nei0;
168181
uint ii1 = i / p.nei0;
169-
uint id = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
182+
uint ii0 = i % p.nei0;
183+
uint id = ids[iter++];
170184
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
171185
uint idx = subgroupBallotExclusiveBitCount(ballot);
172186
if (in_range && id == expert_idx) {
173187
row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0);
174188
}
175189
_ne1 += subgroupBallotBitCount(ballot);
190+
iter &= 15;
176191
}
177192
_ne1_sh = _ne1;
178193
}

0 commit comments

Comments
 (0)