Skip to content

Commit a527b9c

Browse files
committed
Remove redundant code, use non-saturating integer dot, enable all matmul sizes for mmq
1 parent e0dedb2 commit a527b9c

File tree

3 files changed

+17
-39
lines changed

3 files changed

+17
-39
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1899,14 +1899,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
18991899
if (device->mul_mat ## ID ## _s[TYPE]) \
19001900
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
19011901

1902-
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1903-
if (device->mul_mat ## ID ## _l[TYPE]) \
1904-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
1905-
if (device->mul_mat ## ID ## _m[TYPE]) \
1906-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
1907-
if (device->mul_mat ## ID ## _s[TYPE]) \
1908-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
1909-
19101902
// Create 2 variants, {f16,f32} accumulator
19111903
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
19121904
if (device->coopmat_acc_f16_support) { \
@@ -2013,7 +2005,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
20132005
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
20142006
}
20152007
#undef CREATE_MM2
2016-
#undef CREATE_MMQ
20172008
#undef CREATE_MM
20182009
} else
20192010
#endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
@@ -4151,7 +4142,7 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
41514142
return aligned ? mmp->a_s : mmp->s;
41524143
}
41534144

4154-
if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type]) || src1_type == GGML_TYPE_Q8_1) {
4145+
if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) {
41554146
return aligned ? mmp->a_s : mmp->s;
41564147
}
41574148
if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type]) {

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,8 @@ shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
113113
#include "mul_mmq_funcs.comp"
114114

115115
void main() {
116-
#if defined(DATA_A_IQ4_NL)
117-
init_iq4nl_shmem();
116+
#ifdef NEEDS_INIT_IQ_SHMEM
117+
init_iq_shmem(gl_WorkGroupSize);
118118
#endif
119119

120120
#ifdef MUL_MAT_ID
@@ -347,9 +347,8 @@ void main() {
347347
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
348348
int32_t q_sum = 0;
349349
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
350-
q_sum = dotPacked4x8AccSatEXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
351-
cache_b_qs[cc * (BK / 4) + idx_k],
352-
q_sum);
350+
q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
351+
cache_b_qs[cc * (BK / 4) + idx_k]);
353352
}
354353

355354
sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc]);

ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp

Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -40,18 +40,12 @@ i32vec2 repack(uint ib, uint iqs) {
4040
const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
4141
data_a[ib].qs[iqs * 2 + 1]);
4242
const uint32_t vui = pack32(quants);
43-
const uint32_t qh = (uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs);
44-
int32_t v0 = int32_t(vui & 0x0F0F0F0F);
45-
v0 |= int32_t((qh << 4) & 0x00000010); // 0 -> 4
46-
v0 |= int32_t((qh << 11) & 0x00001000); // 1 -> 12
47-
v0 |= int32_t((qh << 18) & 0x00100000); // 2 -> 20
48-
v0 |= int32_t((qh << 25) & 0x10000000); // 3 -> 28
49-
50-
int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F);
51-
v1 |= int32_t((qh >> 12) & 0x00000010); // 16 -> 4
52-
v1 |= int32_t((qh >> 5) & 0x00001000); // 17 -> 12
53-
v1 |= int32_t((qh << 2) & 0x00100000); // 18 -> 20
54-
v1 |= int32_t((qh << 9) & 0x10000000); // 19 -> 28
43+
const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs));
44+
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
45+
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
46+
47+
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
48+
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
5549

5650
return i32vec2(v0, v1);
5751
}
@@ -65,18 +59,12 @@ ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
6559
i32vec2 repack(uint ib, uint iqs) {
6660
// Use 4-byte loads since a q5_1 block (24 bytes) is divisible by 4
6761
const uint32_t vui = data_a_packed32[ib].qs[iqs];
68-
const uint32_t qh = data_a_packed32[ib].qh >> (4 * iqs);
69-
int32_t v0 = int32_t(vui & 0x0F0F0F0F);
70-
v0 |= int32_t((qh << 4) & 0x00000010); // 0 -> 4
71-
v0 |= int32_t((qh << 11) & 0x00001000); // 1 -> 12
72-
v0 |= int32_t((qh << 18) & 0x00100000); // 2 -> 20
73-
v0 |= int32_t((qh << 25) & 0x10000000); // 3 -> 28
74-
75-
int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F);
76-
v1 |= int32_t((qh >> 12) & 0x00000010); // 16 -> 4
77-
v1 |= int32_t((qh >> 5) & 0x00001000); // 17 -> 12
78-
v1 |= int32_t((qh << 2) & 0x00100000); // 18 -> 20
79-
v1 |= int32_t((qh << 9) & 0x10000000); // 19 -> 28
62+
const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
63+
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
64+
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
65+
66+
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
67+
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
8068

8169
return i32vec2(v0, v1);
8270
}

0 commit comments

Comments
 (0)