Commit da201d6

add mxfp4 mmvq
1 parent aad4752 · commit da201d6

4 files changed: +49 −20 lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 6 additions & 0 deletions
@@ -3405,6 +3405,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32", arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_q8_1_f32", arr_dmmv_q5_1_q8_1_f32_len[reduc], arr_dmmv_q5_1_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
             ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32", arr_dmmv_q8_0_q8_1_f32_len[reduc], arr_dmmv_q8_0_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
+
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_q8_1_f32", arr_dmmv_mxfp4_q8_1_f32_len[reduc], arr_dmmv_mxfp4_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
+
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_q8_1_f32", arr_dmmv_q2_k_q8_1_f32_len[reduc], arr_dmmv_q2_k_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
         }
 #endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
     }

@@ -5143,6 +5147,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_MXFP4:
+        case GGML_TYPE_Q2_K:
             break;
         default:
             return nullptr;
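
Note: the two hunks above register the new MXFP4 and Q2_K integer-dot pipelines and let ggml_vk_get_dequantize_mul_mat_vec hand those source types to the q8_1 path. A minimal hypothetical helper (not part of this commit) that mirrors the extended case list, assuming the remaining legacy-quant cases sit just above the hunk:

// Hypothetical sketch only: which source types the q8_1 mul_mat_vec path now accepts.
#include "ggml.h"

static bool vk_mmvq_type_supported(ggml_type type) {
    switch (type) {
        case GGML_TYPE_Q4_0:   // legacy quants, assumed listed above the hunk
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_MXFP4:  // added by this commit
        case GGML_TYPE_Q2_K:   // added by this commit
            return true;
        default:
            return false;
    }
}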

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp

Lines changed: 5 additions & 1 deletion
@@ -43,7 +43,7 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
             const uint a_block_idx = (ibi + col)/QUANT_K + a_offset;
             ibi += p.ncols;

-            temp[j][n] += mmvq_dot_product(a_block_idx, b_qs_idx, 4);
+            temp[j][n] += mmvq_dot_product(a_block_idx, b_qs_idx);
         }
     }
 }

@@ -108,6 +108,10 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
 void main() {
     const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);

+#ifdef NEEDS_INIT_IQ_SHMEM
+    init_iq_shmem(gl_WorkGroupSize);
+#endif
+
     // do NUM_ROWS at a time, unless there aren't enough remaining rows
     if (first_row + NUM_ROWS <= p.stride_d) {
        compute_outputs(first_row, NUM_ROWS);
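
Note: the init_iq_shmem call matters here because MXFP4 is a codebook format: each 4-bit value is an index into a 16-entry table rather than a two's-complement integer, and shader variants built with NEEDS_INIT_IQ_SHMEM stage that table into shared memory once per workgroup before use. A rough scalar illustration, not shader code; the table below mirrors ggml's CPU-side kvalues_mxfp4 (the E2M1 code points scaled by 2) and is an assumption, not part of this diff:

// Hypothetical C++ illustration of the MXFP4 codebook lookup that the shader's
// repack() performs four bytes at a time with unpack8/pack32.
#include <cstdint>

static const int8_t kvalues_mxfp4[16] = {
    0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12
};

// One packed byte carries two 4-bit indices: low nibble first, high nibble second.
static inline void mxfp4_unpack_byte(uint8_t packed, int8_t &lo, int8_t &hi) {
    lo = kvalues_mxfp4[packed & 0x0F];
    hi = kvalues_mxfp4[packed >> 4];
}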

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl

Lines changed: 36 additions & 17 deletions
@@ -10,18 +10,18 @@ FLOAT_TYPE get_dm(uint ib) {
 }
 #endif

-#if defined(DATA_A_MXFP4)
-FLOAT_TYPE get_dm(uint ib) {
-    return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e));
-}
-#endif
-
 #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
 FLOAT_TYPE_VEC2 get_dm(uint ib) {
     return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
 }
 #endif

+#if defined(DATA_A_MXFP4)
+FLOAT_TYPE get_dm(uint ib) {
+    return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e));
+}
+#endif
+
 #if defined(DATA_A_Q2_K)
 FLOAT_TYPE_VEC2 get_dm(uint ib) {
     const uint ib_k = ib / 8;

@@ -115,22 +115,25 @@ ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int
 #if defined(DATA_A_MXFP4)
 // 1-byte loads for mxfp4 blocks (17 bytes)
 i32vec2 repack(uint ib, uint iqs) {
-    const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
-                                          data_a[ib].qs[iqs * 4 + 1],
-                                          data_a[ib].qs[iqs * 4 + 2],
-                                          data_a[ib].qs[iqs * 4 + 3]));
+    const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
+                                      data_a[ib].qs[iqs * 4 + 1],
+                                      data_a[ib].qs[iqs * 4 + 2],
+                                      data_a[ib].qs[iqs * 4 + 3]));
+
+    const u8vec4 i_a0 = unpack8( qs       & 0x0F0F0F0F);
+    const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);

-    return i32vec2( quants       & 0x0F0F0F0F,
-                   (quants >> 4) & 0x0F0F0F0F);
+    return i32vec2(pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w])),
+                   pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w])));
 }

 ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
-    return ACC_TYPE(da * dsb.x * float(q_sum));
+    return ACC_TYPE(da * dsb.x * float(q_sum) * 0.5);
 }
 #endif

 #if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4)
-FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs, const int32_t sum_divisor) {
+FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
     int32_t q_sum = 0;
 #if QUANT_R == 2
     const i32vec2 data_a_qs = repack(ib_a, iqs);

@@ -147,7 +150,8 @@ FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs, const int32_t sum_d
                                cache_b_qs[1]);
 #endif

-    return mul_q8_1(q_sum, get_dm(ib_a), cache_b_ds, sum_divisor);
+    // 2 quants per call => divide sums by 8/2 = 4
+    return mul_q8_1(q_sum, get_dm(ib_a), cache_b_ds, 4);
 }
 #endif

@@ -170,8 +174,23 @@ uint8_t get_scale(uint ib, uint iqs) {
     return data_a[ib_k].scales[iqs_k / 4];
 }

-ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
-    return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m)));
+FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
+    int32_t sum_d = 0;
+    int32_t sum_m = 0;
+
+    const int32_t qs_a0 = repack(ib_a, iqs * 2);
+    const int32_t qs_a1 = repack(ib_a, iqs * 2 + 1);
+    const uint8_t scale = get_scale(ib_a, iqs * 2);
+    const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.
+
+    sum_d += dotPacked4x8EXT(qs_a0, cache_b_qs[0]) * (scale & 0xF);
+    sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[0]);
+
+    sum_d += dotPacked4x8EXT(qs_a1, cache_b_qs[1]) * (scale & 0xF);
+    sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[1]);
+
+    const vec2 dm = get_dm(ib_a);
+    return ACC_TYPE(float(cache_b_ds.x) * (float(dm.x) * float(sum_d) - float(dm.y) * float(sum_m) / 4));
 }
 #endif
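
Note: the * 0.5 in the MXFP4 mul_q8_1 and the table lookup in repack() go together. Assuming kvalues_mxfp4 stores the E2M1 code points scaled by 2 (as ggml's CPU reference table does) and e8m0_to_fp32(e) equals 2^(e-127), the integer dot product comes out doubled and must be halved once. A scalar reference sketch under those assumptions; the names below are illustrative, not backend API:

// Hypothetical per-block reference for the mxfp4 x q8_1 dot product.
#include <cmath>
#include <cstdint>

static float e8m0_to_fp32_ref(uint8_t e) {
    return std::ldexp(1.0f, int(e) - 127);          // 2^(e-127), the MXFP4 block scale
}

// q_a: 32 codebook values from kvalues_mxfp4 (already doubled),
// q_b: 32 int8 activations, d_b: the q8_1 block scale (cache_b_ds.x in the shader).
static float mxfp4_q8_1_block_dot(uint8_t e, const int8_t q_a[32],
                                  const int8_t q_b[32], float d_b) {
    int32_t q_sum = 0;
    for (int i = 0; i < 32; ++i) {
        q_sum += int32_t(q_a[i]) * int32_t(q_b[i]); // what dotPacked4x8EXT accumulates
    }
    return e8m0_to_fp32_ref(e) * d_b * float(q_sum) * 0.5f; // 0.5 undoes the doubled table values
}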

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 2 deletions
@@ -664,7 +664,7 @@ void process_shaders() {

         // mul mat vec with integer dot product
 #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
-        if (is_legacy_quant(tname)) {
+        if (is_legacy_quant(tname) || tname == "mxfp4" || tname == "q2_k") {
             string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
             string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
             string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));

@@ -1040,7 +1040,7 @@ void write_output_files() {

     for (const std::string& btype : btypes) {
         for (const auto& tname : type_names) {
-            if (btype == "q8_1" && !is_legacy_quant(tname)) {
+            if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && tname != "q2_k") {
                 continue;
             }
             hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n";
