Skip to content

Commit fd3757d

Browse files
ikawrakow (Iwan Kawrakow)
and co-author authored
Biased mmvq: minor optimization (#880)
Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent c7dbe3f commit fd3757d

File tree

2 files changed

+46
-7
lines changed

2 files changed: +46 additions, -7 deletions

examples/llama-bench/llama-bench.cpp

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ struct cmd_params {
265265
bool no_fug = false;
266266
bool use_thp = false;
267267
bool no_ooae = false;
268+
bool mqkv = false;
268269
output_formats output_format;
269270
output_formats output_format_stderr;
270271
};
@@ -303,6 +304,7 @@ static const cmd_params cmd_params_defaults = {
303304
/* no_fug */ false,
304305
/* use_thp */ false,
305306
/* no_ooae */ false,
307+
/* mqkv */ false,
306308
/* output_format */ MARKDOWN,
307309
/* output_format_stderr */ NONE,
308310
};
@@ -342,6 +344,7 @@ static void print_usage(int /* argc */, char ** argv) {
342344
printf(" -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0");
343345
printf(" -w, --warmup <0|1> (default: %s)\n", cmd_params_defaults.warmup ? "1" : "0");
344346
printf(" -rtr, --run-time-repack <0|1> (default: %s)\n", cmd_params_defaults.repack ? "1" : "0");
347+
printf(" -mqkv, --merge-qkv (default: %s)\n", cmd_params_defaults.mqkv ? "1" : "0");
345348
printf(" -thp, --transparent-huge-pages <0|1> (default: %s)\n", cmd_params_defaults.use_thp? "1" : "0");
346349
printf(" -ot, --override-tensor pattern (default: none)\n");
347350
printf(" -fmoe, --fused-moe <0|1> (default: %s)\n", cmd_params_defaults.fmoe? "1" : "0");
@@ -733,6 +736,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
733736
break;
734737
}
735738
params.repack = std::stoi(argv[i]);
739+
} else if (arg == "-mqkv" || arg == "--merge-qkv") {
740+
if (++i >= argc) {
741+
invalid_param = true;
742+
break;
743+
}
744+
params.mqkv = std::stoi(argv[i]);
736745
} else if (arg == "-thp" || arg == "--transparent-huge-pages") {
737746
if (++i >= argc) {
738747
invalid_param = true;
@@ -851,6 +860,7 @@ struct cmd_params_instance {
851860
bool no_fug = false;
852861
bool use_thp = false;
853862
bool no_ooae = false;
863+
bool mqkv = false;
854864
const llama_model_tensor_buft_override* buft_overrides;
855865

856866
llama_model_params to_llama_mparams() const {
@@ -866,6 +876,7 @@ struct cmd_params_instance {
866876
mparams.use_mmap = use_mmap;
867877
mparams.repack_tensors = repack;
868878
mparams.use_thp = use_thp;
879+
mparams.merge_qkv = mqkv;
869880
mparams.tensor_buft_overrides = buft_overrides;
870881

871882
return mparams;
@@ -879,6 +890,7 @@ struct cmd_params_instance {
879890
main_gpu == other.main_gpu &&
880891
use_mmap == other.use_mmap &&
881892
repack == other.repack &&
893+
mqkv == other.mqkv &&
882894
use_thp == other.use_thp &&
883895
tensor_split == other.tensor_split;
884896
}
@@ -961,6 +973,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
961973
/* .no_fug = */ params.no_fug,
962974
/* .use_thp = */ params.use_thp,
963975
/* .no_ooae = */ params.no_ooae,
976+
/* .mqkv = */ params.mqkv,
964977
/* .buft_overrides=*/ params.buft_overrides.data(),
965978
};
966979
instances.push_back(instance);
@@ -998,6 +1011,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
9981011
/* .no_fug = */ params.no_fug,
9991012
/* .use_thp = */ params.use_thp,
10001013
/* .no_ooae = */ params.no_ooae,
1014+
/* .mqkv = */ params.mqkv,
10011015
/* .buft_overrides=*/ params.buft_overrides.data(),
10021016
};
10031017
instances.push_back(instance);
@@ -1035,6 +1049,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
10351049
/* .no_fug = */ params.no_fug,
10361050
/* .use_thp = */ params.use_thp,
10371051
/* .no_ooae = */ params.no_ooae,
1052+
/* .mqkv = */ params.mqkv,
10381053
/* .buft_overrides=*/ params.buft_overrides.data(),
10391054
};
10401055
instances.push_back(instance);
@@ -1071,7 +1086,8 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
10711086
/* .ger = */ params.ger,
10721087
/* .no_fug = */ params.no_fug,
10731088
/* .use_thp = */ params.use_thp,
1074-
/* .no_ooae = */ params.no_ooae,
1089+
/* .no_ooae = */ params.no_ooae,
1090+
/* .mqkv = */ params.mqkv,
10751091
/* .buft_overrides=*/ params.buft_overrides.data(),
10761092
};
10771093
instances.push_back(instance);
@@ -1120,6 +1136,7 @@ struct test {
11201136
bool no_fug = false;
11211137
bool use_thp = false;
11221138
bool no_ooae = false;
1139+
bool mqkv = false;
11231140
int n_prompt;
11241141
int n_gen;
11251142
std::string test_time;
@@ -1152,6 +1169,7 @@ struct test {
11521169
use_mmap = inst.use_mmap;
11531170
embeddings = inst.embeddings;
11541171
repack = inst.repack;
1172+
mqkv = inst.mqkv;
11551173
fmoe = inst.fmoe;
11561174
ger = inst.ger;
11571175
no_fug = inst.no_fug;
@@ -1247,7 +1265,7 @@ struct test {
12471265
"n_threads", "type_k", "type_v",
12481266
"n_gpu_layers", "split_mode",
12491267
"main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", "ser",
1250-
"tensor_split", "use_mmap", "embeddings", "repack", "fused_moe", "grouped_er", "fused_up_gate", "use_thp", "ooae",
1268+
"tensor_split", "use_mmap", "embeddings", "repack", "mqkv", "fused_moe", "grouped_er", "fused_up_gate", "use_thp", "ooae",
12511269
"n_prompt", "n_gen", "test_time",
12521270
"avg_ns", "stddev_ns",
12531271
"avg_ts", "stddev_ts", "test",
@@ -1269,7 +1287,7 @@ struct test {
12691287
if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" ||
12701288
field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
12711289
field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" || field == "use_thp" ||
1272-
field == "fused_moe" || field == "grouped_er" || field == "fused_up_gate" || field == "ooae") {
1290+
field == "fused_moe" || field == "grouped_er" || field == "fused_up_gate" || field == "ooae" || field == "mqkv") {
12731291
return BOOL;
12741292
}
12751293
if (field == "avg_ts" || field == "stddev_ts") {
@@ -1313,7 +1331,7 @@ struct test {
13131331
std::to_string(mla_attn), std::to_string(attn_max_batch), ser_to_string(ser),
13141332
tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings),
13151333
std::to_string(repack), std::to_string(fmoe), std::to_string(ger),
1316-
std::to_string(no_fug), std::to_string(use_thp), std::to_string(no_ooae),
1334+
std::to_string(no_fug), std::to_string(use_thp), std::to_string(no_ooae), std::to_string(mqkv),
13171335
std::to_string(n_prompt), std::to_string(n_gen), test_time,
13181336
std::to_string(avg_ns()), std::to_string(stdev_ns()),
13191337
std::to_string(avg_ts()), std::to_string(stdev_ts()),
@@ -1491,6 +1509,9 @@ struct markdown_printer : public printer {
14911509
if (field == "repack") {
14921510
return 3;
14931511
}
1512+
if (field == "mqkv") {
1513+
return 4;
1514+
}
14941515
if (field == "use_thp") {
14951516
return 3;
14961517
}
@@ -1549,6 +1570,9 @@ struct markdown_printer : public printer {
15491570
if (field == "repack") {
15501571
return "rtr";
15511572
}
1573+
if (field == "mqkv") {
1574+
return "mqkv";
1575+
}
15521576
if (field == "use_thp") {
15531577
return "thp";
15541578
}
@@ -1634,6 +1658,9 @@ struct markdown_printer : public printer {
16341658
if (params.repack != cmd_params_defaults.repack) {
16351659
fields.emplace_back("repack");
16361660
}
1661+
if (params.mqkv != cmd_params_defaults.mqkv) {
1662+
fields.emplace_back("mqkv");
1663+
}
16371664
if (params.use_thp != cmd_params_defaults.use_thp) {
16381665
fields.emplace_back("use_thp");
16391666
}

ggml/src/ggml-cuda/mmvq-templates.cuh

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,10 @@ static __device__ void mul_mat_vec_q(
112112
}
113113
}
114114

115+
float local_bias[rows_per_cuda_block] = { 0.0f };
116+
if (bias && threadIdx.y == 0 && threadIdx.x < rows_per_cuda_block && row0 + threadIdx.x < nrows_dst) {
117+
local_bias[threadIdx.x] = bias[row0 + threadIdx.x];
118+
}
115119
__shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
116120
if (threadIdx.y > 0) {
117121
#pragma unroll
@@ -140,7 +144,7 @@ static __device__ void mul_mat_vec_q(
140144
}
141145

142146
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
143-
dst[j*nrows_dst + row0 + threadIdx.x] = bias ? tmp[j][threadIdx.x] + bias[j*nrows_dst + row0 + threadIdx.x] : tmp[j][threadIdx.x];
147+
dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x] + local_bias[threadIdx.x];
144148
}
145149
}
146150
}
@@ -176,6 +180,14 @@ static __device__ void fused_mul_mat_vec_q(
176180
// partial sum for each thread
177181
float tmp_u[ncols_y][rows_per_cuda_block] = {0.0f};
178182
float tmp_g[ncols_y][rows_per_cuda_block] = {0.0f};
183+
float local_bias_u[rows_per_cuda_block] = { 0.0f };
184+
float local_bias_g[rows_per_cuda_block] = { 0.0f };
185+
if (bias_u && threadIdx.y == 0 && threadIdx.x < rows_per_cuda_block && row0 + threadIdx.x < nrows_dst) {
186+
local_bias_u[threadIdx.x] = bias_u[row0 + threadIdx.x];
187+
}
188+
if (bias_g && threadIdx.y == 0 && threadIdx.x < rows_per_cuda_block && row0 + threadIdx.x < nrows_dst) {
189+
local_bias_g[threadIdx.x] = bias_g[row0 + threadIdx.x];
190+
}
179191

180192
const block_q8_1 * y = (const block_q8_1 *) vy;
181193

@@ -242,8 +254,8 @@ static __device__ void fused_mul_mat_vec_q(
242254
default: {
243255
constexpr float alpha = 1.702f;
244256
constexpr float limit = 7.0f;
245-
g += bias_g[j*nrows_dst + row0 + threadIdx.x];
246-
u += bias_u[j*nrows_dst + row0 + threadIdx.x];
257+
g += local_bias_g[threadIdx.x];
258+
u += local_bias_u[threadIdx.x];
247259
g = fminf(g, limit);
248260
u = fmaxf(fminf(u, limit), -limit);
249261
r = g / (1.0f + expf(-g * alpha)) * (1.0f + u);

0 commit comments

Comments (0)