Skip to content

Commit 9531111

Browse files
committed
Add rule-based kernel tuning.
1 parent ea2876f commit 9531111

File tree

1 file changed

+62
-31
lines changed

1 file changed

+62
-31
lines changed

ggml/src/ggml-cpu/tmac/lut_mul_mat.cpp

Lines changed: 62 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,7 @@ namespace ggml::cpu::tmac {
3838

3939
/****** T-MAC properties ******/
4040
constexpr size_t kAllocAlignment = 64;
41+
const int n_threads = 8;
4142

4243
static tmac_tensor_extra * tmac_tensor_extras = nullptr;
4344
static size_t tmac_tensor_extras_index = 0;
@@ -411,48 +412,78 @@ static void ggml_tmac_tune_kernel_config(const struct ggml_tensor * tensor, int
411412

412413

413414
double min_time = 1e9;
414-
struct tmac_kernel_config best_kcfg;
415-
for (int bm: bms) {
416-
if (M % (bm/bits) != 0 || bm % bits != 0) {
417-
continue;
418-
}
419-
420-
kernel_config.bm = bm;
421-
for (int n: bns) {
422-
if ((N >= n && N % n != 0) || (N < n && n != bns[0])) {
415+
struct tmac_kernel_config best_kcfg = kernel_config;
416+
417+
auto profile_based = [&]() {
418+
for (int bm: bms) {
419+
if (M % (bm/bits) != 0 || bm % bits != 0) {
423420
continue;
424421
}
425-
426-
for (int kfactor: kfactors) {
427-
if (kfactor < kernel_config.actk) {
422+
423+
kernel_config.bm = bm;
424+
for (int n: bns) {
425+
if ((N >= n && N % n != 0) || (N < n && n != bns[0])) {
428426
continue;
429427
}
430428

431-
kernel_config.kfactor = kfactor;
432-
// insert to dict for finding
433-
insert_or_assign_tmac_kernel_config(M, K, bits, kernel_config);
434-
struct tmac_run_single_kernel_settings settings = {
435-
/* .test_time_ms = */ 5000,
436-
/* .M = */ M,
437-
/* .N = */ N,
438-
/* .K = */ K,
439-
/* .n = */ n,
440-
/* .kernel_config = */ &kernel_config
441-
};
442-
double this_time;
443-
ggml_tmac_tune_single_kernel_config(&settings, this_time);
444-
GGML_LOG_INFO("Tuned kernel config: M=%d, N=%d, K=%d, bm=%d, n=%d, kfactor=%d, bits=%d, g=%d, ngroups_per_elem=%d, q_group_size=%d, act_group_size=%d\t TIME: %.4f ms\n",
445-
M, N, K, bm, n, kfactor, bits, kernel_config.g, kernel_config.ngroups_per_elem, kernel_config.q_group_size, kernel_config.act_group_size, this_time);
446-
if (this_time < min_time) {
447-
min_time = this_time;
448-
best_kcfg = kernel_config;
429+
for (int kfactor: kfactors) {
430+
if (kfactor < kernel_config.actk) {
431+
continue;
432+
}
433+
434+
kernel_config.kfactor = kfactor;
435+
// insert to dict for finding
436+
insert_or_assign_tmac_kernel_config(M, K, bits, kernel_config);
437+
struct tmac_run_single_kernel_settings settings = {
438+
/* .test_time_ms = */ 5000,
439+
/* .M = */ M,
440+
/* .N = */ N,
441+
/* .K = */ K,
442+
/* .n = */ n,
443+
/* .kernel_config = */ &kernel_config
444+
};
445+
double this_time;
446+
ggml_tmac_tune_single_kernel_config(&settings, this_time);
447+
if (this_time < min_time) {
448+
min_time = this_time;
449+
best_kcfg = kernel_config;
450+
}
449451
}
450452
}
453+
};
454+
};
455+
auto rule_based = [&]() {
456+
float smallest_penalty = 1e9;
457+
for (int bm: bms) {
458+
if (M % (bm/bits) != 0 || bm % bits != 0) {
459+
continue;
460+
}
461+
int num_tiles = M / (bm/bits);
462+
int num_groups = (num_tiles + n_threads - 1) / n_threads;
463+
float penalty = 0.1 * num_groups + (num_groups - 1.0 * num_tiles / n_threads) / num_groups;
464+
if (penalty <= smallest_penalty) {
465+
smallest_penalty = penalty;
466+
best_kcfg.bm = bm;
467+
}
451468
}
452-
}
469+
470+
int largest_kfactor = 0;
471+
for (int kfactor: kfactors) {
472+
if (kfactor < kernel_config.actk) {
473+
continue;
474+
}
475+
if (kfactor > largest_kfactor) {
476+
largest_kfactor = kfactor;
477+
best_kcfg.kfactor = kfactor;
478+
}
479+
}
480+
};
481+
rule_based();
453482

454483
// Save the results
455484
insert_or_assign_tmac_kernel_config(M, K, bits, best_kcfg);
485+
GGML_LOG_INFO("Tuned kernel config: M=%d, N=%d, K=%d, bm=%d, kfactor=%d, bits=%d, g=%d, ngroups_per_elem=%d, q_group_size=%d, act_group_size=%d\n",
486+
M, N, K, best_kcfg.bm, best_kcfg.kfactor, bits, best_kcfg.g, best_kcfg.ngroups_per_elem, best_kcfg.q_group_size, best_kcfg.act_group_size);
456487
}
457488

458489

0 commit comments

Comments
 (0)