Skip to content

Commit 19d5bbb

Browse files
committed
Q4_0 Supported.
1 parent 802782c commit 19d5bbb

File tree

2 files changed

+24
-8
lines changed

2 files changed

+24
-8
lines changed

ggml/src/ggml-cpu/tmac/lut_mul_mat.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,7 @@ static void ggml_tmac_tune_kernel_config(const struct ggml_tensor * tensor, int
413413
}
414414

415415
for (int kfactor: kfactors) {
416-
if (kfactor < kernel_config.actk) {
416+
if ((kfactor < kernel_config.actk) || (kfactor * kernel_config.g > kernel_config.q_group_size)) {
417417
continue;
418418
}
419419

@@ -455,7 +455,7 @@ static void ggml_tmac_tune_kernel_config(const struct ggml_tensor * tensor, int
455455

456456
int largest_kfactor = 0;
457457
for (int kfactor: kfactors) {
458-
if (kfactor < kernel_config.actk) {
458+
if ((kfactor < kernel_config.actk) || (kfactor * kernel_config.g > kernel_config.q_group_size)) {
459459
continue;
460460
}
461461
if (kfactor > largest_kfactor) {
@@ -468,8 +468,8 @@ static void ggml_tmac_tune_kernel_config(const struct ggml_tensor * tensor, int
468468

469469
// Save the results
470470
insert_or_assign_tmac_kernel_config(M, K, bits, best_kcfg);
471-
GGML_LOG_INFO("Tuned kernel config: M=%d, N=%d, K=%d, bm=%d, kfactor=%d, bits=%d, g=%d, ngroups_per_elem=%d, q_group_size=%d, act_group_size=%d\n",
472-
M, N, K, best_kcfg.bm, best_kcfg.kfactor, bits, best_kcfg.g, best_kcfg.ngroups_per_elem, best_kcfg.q_group_size, best_kcfg.act_group_size);
471+
GGML_LOG_INFO("Tuned kernel config: M=%d, N=%d, K=%d, bm=%d, kfactor=%d, bits=%d, actk=%d, g=%d, ngroups_per_elem=%d, q_group_size=%d, act_group_size=%d\n",
472+
M, N, K, best_kcfg.bm, best_kcfg.kfactor, bits, best_kcfg.actk, best_kcfg.g, best_kcfg.ngroups_per_elem, best_kcfg.q_group_size, best_kcfg.act_group_size);
473473
}
474474

475475

ggml/src/ggml-cpu/tmac/tbl.cpp

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -834,14 +834,30 @@ void qgemm_lut_int8_g4(
834834
tbl_int32_reset(bm * sizeof(tmac_float_type) / sizeof(int32_t), (&(((int32_t*)CBits)[0])));
835835

836836
int32_t k_outer_max = K / (kfactor * g);
837+
int32_t scale_gs = q_group_size / (kfactor * g);
838+
int32_t scale_idx_shfr = 0;
839+
if (scale_gs == 1) {
840+
scale_idx_shfr = 0;
841+
} else if (scale_gs == 2) {
842+
scale_idx_shfr = 1;
843+
} else if (scale_gs == 4) {
844+
scale_idx_shfr = 2;
845+
} else if (scale_gs == 8) {
846+
scale_idx_shfr = 3;
847+
} else {
848+
fprintf(stderr, "q_group_size=%d, kfactor=%d, g=%d\n", q_group_size, kfactor, g);
849+
fprintf(stderr, "Unsupported scale group size over kfactor. Expected {1,2,4,8}, got %d.\n", scale_gs);
850+
throw std::runtime_error("");
851+
}
852+
837853
for (int32_t k_outer = 0; k_outer < k_outer_max; k_outer++) {
838854
uint8_t * a = ((uint8_t *)A) + k_outer * bm * kfactor / ngroups_per_elem;
839855
tmac_float_type * scales = one_scale ? (tmac_float_type *)Scales :
840-
has_zero_point ? ((tmac_float_type *)Scales) + k_outer * m * 2:
841-
((tmac_float_type *)Scales) + k_outer * m;
856+
has_zero_point ? ((tmac_float_type *)Scales) + (k_outer >> scale_idx_shfr) * m * 2:
857+
((tmac_float_type *)Scales) + (k_outer >> scale_idx_shfr) * m;
842858
int8_t * lut = ((int8_t *)LUT) + k_outer * kfactor * int(pow(2, g));
843-
tmac_float_type * lut_scales = ((tmac_float_type *)LUT_Scales) + (k_outer * q_group_size / act_group_size); // k_outer * kfactor * g / act_group_size == k_outer
844-
tmac_float_type * lut_biases = ((tmac_float_type *)LUT_Biases) + (k_outer * q_group_size / act_group_size); // k_outer * kfactor * g / act_group_size == k_outer
859+
tmac_float_type * lut_scales = ((tmac_float_type *)LUT_Scales) + (k_outer * kfactor * g / act_group_size);
860+
tmac_float_type * lut_biases = ((tmac_float_type *)LUT_Biases) + (k_outer * kfactor * g / act_group_size);
845861

846862
if (has_scale && kfactor == 8 && bits == 2 && actk == 8 && has_zero_point && !one_scale) {
847863
tbl_g4_int8_float_update_impl<true, 8, 2, 8, false, true, false>(

0 commit comments

Comments
 (0)