Skip to content

Commit 9271981

Browse files
committed
add q3_k mmvq
1 parent 3aa74df commit 9271981

File tree

3 files changed

+51
-2
lines changed

3 files changed

+51
-2
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3409,6 +3409,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
34093409
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_q8_1_f32", arr_dmmv_mxfp4_q8_1_f32_len[reduc], arr_dmmv_mxfp4_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
34103410

34113411
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_q8_1_f32", arr_dmmv_q2_k_q8_1_f32_len[reduc], arr_dmmv_q2_k_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
3412+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_q8_1_f32", arr_dmmv_q3_k_q8_1_f32_len[reduc], arr_dmmv_q3_k_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
34123413
}
34133414
#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
34143415
}
@@ -5149,6 +5150,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
51495150
case GGML_TYPE_Q8_0:
51505151
case GGML_TYPE_MXFP4:
51515152
case GGML_TYPE_Q2_K:
5153+
case GGML_TYPE_Q3_K:
51525154
break;
51535155
default:
51545156
return nullptr;

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,53 @@ FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
194194
}
195195
#endif
196196

197+
#if defined(DATA_A_Q3_K)
198+
// 2-byte loads for Q3_K blocks (110 bytes)
199+
i32vec2 repack2(uint ib, uint iqs) { // Builds two packed words of four signed 8-bit Q3_K quants each (range -4..3) from the 2-bit qs field and the 1-bit hmask field.
200+
const uint ib_k = ib / 8; // element of data_a_packed16 to read: eight consecutive ib values share one element (presumably 8 sub-blocks per Q3_K super-block - TODO confirm)
201+
const uint iqs_k = (ib % 8) * 8 + iqs; // position inside that element: each step of (ib % 8) advances the position by 8
202+
203+
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8); // base index into qs, doubled below to address 16-bit words; the middle bits of iqs_k are not used here, they select the shift instead
204+
const uint qs_shift = ((iqs_k % 32) / 8) * 2; // which 2-bit field of every qs byte to take: shift of 0, 2, 4 or 6
205+
const uint hm_shift = iqs_k / 8; // which bit of every hmask byte to take
206+
207+
// bitwise OR to add 4 if hmask is set, subtract later (the "- int8_t(4)" in the return statement)
208+
const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) | // mask 0x0303 keeps the low 2 bits of both bytes of the 16-bit word
209+
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2)); // mask 0x0101 keeps 1 bit per byte, "<< 2" moves it to bit 2 (value 4). NOTE(review): hmask is indexed with the raw iqs, which equals iqs_k % 8 only while iqs < 8 - confirm all callers respect that
210+
const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1 ] >> qs_shift) & uint16_t(0x0303))) | // same extraction for the next 16-bit word (+1)
211+
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2));
212+
const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2 ] >> qs_shift) & uint16_t(0x0303))) | // word +2: first half of the second output word
213+
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2));
214+
const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3 ] >> qs_shift) & uint16_t(0x0303))) | // word +3: second half of the second output word
215+
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2));
216+
217+
return i32vec2(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y) - int8_t(4)), // .x: four quants from words +0 and +1; subtracting 4 maps the combined 0..7 values to -4..3
218+
pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y) - int8_t(4))); // .y: four quants from words +2 and +3, same recentering
219+
}
220+
221+
float get_d_scale(uint ib, uint iqs) { // Returns data_a[ib_k].d multiplied by the recentered 6-bit scale that covers position (ib, iqs).
222+
const uint ib_k = ib / 8; // element of data_a to read: eight consecutive ib values share one element
223+
const uint iqs_k = (ib % 8) * 8 + iqs; // position inside that element: each step of (ib % 8) advances the position by 8
224+
const uint is = iqs_k / 4; // scale index: four consecutive iqs_k positions share one scale
225+
226+
const int8_t scale = int8_t(((data_a[ib_k].scales[is % 8 ] >> (4 * (is / 8))) & 0x0F0F) | // low 4 bits: entries 0..7 of scales hold two nibbles each, is / 8 selects the nibble. NOTE(review): if scales elements are 8-bit, only the low byte of the 0x0F0F mask has any effect - confirm element type
227+
(((data_a[ib_k].scales[8 + (is % 4)] >> (2 * (is / 4))) & 0x0303) << 4)); // upper 2 bits: entries 8..11 of scales hold four 2-bit fields each, is / 4 selects the field, "<< 4" places it above the nibble
228+
return float(data_a[ib_k].d) * float(scale - 32); // subtract 32 to recenter the assembled scale before multiplying by d
229+
}
230+
231+
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { // Q3_K variant: integer dot product of eight repacked A quants with two cached B words, scaled by both operands' scales.
232+
int32_t q_sum = 0; // integer accumulator for the packed dot products
233+
234+
const i32vec2 qs_a = repack2(ib_a, iqs * 2); // each call consumes two packed words, so the position handed to the helpers is iqs * 2
235+
const float d_scale = get_d_scale(ib_a, iqs * 2); // one combined scale is used for both words
236+
237+
q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]); // four-component 8-bit dot product of the first A word with cache_b_qs[0] (presumably the cached B-operand quants - defined outside this view, TODO confirm)
238+
q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]); // same for the second A word and cache_b_qs[1]
239+
240+
return FLOAT_TYPE(float(cache_b_ds.x) * d_scale * float(q_sum)); // only the .x component of cache_b_ds is used (presumably the B-block scale - TODO confirm); no separate offset term is added
241+
}
242+
#endif
243+
197244
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
198245
// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
199246
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -664,7 +664,7 @@ void process_shaders() {
664664

665665
// mul mat vec with integer dot product
666666
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
667-
if (is_legacy_quant(tname) || tname == "mxfp4" || tname == "q2_k") {
667+
if (is_legacy_quant(tname) || tname == "mxfp4" || tname == "q2_k" || tname == "q3_k") {
668668
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
669669
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
670670
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
@@ -1040,7 +1040,7 @@ void write_output_files() {
10401040

10411041
for (const std::string& btype : btypes) {
10421042
for (const auto& tname : type_names) {
1043-
if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && tname != "q2_k") {
1043+
if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && tname != "q2_k" && tname != "q3_k") {
10441044
continue;
10451045
}
10461046
hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n";

0 commit comments

Comments
 (0)