diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index 3e948e68b..afa46d2c6 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -265,6 +265,7 @@ struct cmd_params {
     bool no_fug = false;
     bool use_thp = false;
     bool no_ooae = false;
+    bool mqkv = false;
     output_formats output_format;
     output_formats output_format_stderr;
 };
@@ -303,6 +304,7 @@ static const cmd_params cmd_params_defaults = {
     /* no_fug               */ false,
     /* use_thp              */ false,
     /* no_ooae              */ false,
+    /* mqkv                 */ false,
     /* output_format        */ MARKDOWN,
     /* output_format_stderr */ NONE,
 };
@@ -342,6 +344,7 @@ static void print_usage(int /* argc */, char ** argv) {
     printf("  -v, --verbose                             (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0");
     printf("  -w, --warmup <0|1>                        (default: %s)\n", cmd_params_defaults.warmup ? "1" : "0");
     printf("  -rtr, --run-time-repack <0|1>             (default: %s)\n", cmd_params_defaults.repack ? "1" : "0");
+    printf("  -mqkv, --merge-qkv <0|1>                  (default: %s)\n", cmd_params_defaults.mqkv ? "1" : "0");
     printf("  -thp, --transparent-huge-pages <0|1>      (default: %s)\n", cmd_params_defaults.use_thp? "1" : "0");
     printf("  -ot, --override-tensor pattern            (default: none)\n");
     printf("  -fmoe, --fused-moe <0|1>                  (default: %s)\n", cmd_params_defaults.fmoe? "1" : "0");
@@ -733,6 +736,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
                 break;
             }
             params.repack = std::stoi(argv[i]);
+        } else if (arg == "-mqkv" || arg == "--merge-qkv") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.mqkv = std::stoi(argv[i]);
         } else if (arg == "-thp" || arg == "--transparent-huge-pages") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
@@ -851,6 +860,7 @@ struct cmd_params_instance {
     bool no_fug = false;
     bool use_thp = false;
     bool no_ooae = false;
+    bool mqkv = false;
     const llama_model_tensor_buft_override* buft_overrides;

     llama_model_params to_llama_mparams() const {
@@ -866,6 +876,7 @@ struct cmd_params_instance {
         mparams.use_mmap = use_mmap;
         mparams.repack_tensors = repack;
         mparams.use_thp = use_thp;
+        mparams.merge_qkv = mqkv;
         mparams.tensor_buft_overrides = buft_overrides;

         return mparams;
@@ -879,6 +890,7 @@ struct cmd_params_instance {
                main_gpu == other.main_gpu &&
                use_mmap == other.use_mmap &&
                repack == other.repack &&
+               mqkv == other.mqkv &&
                use_thp == other.use_thp &&
                tensor_split == other.tensor_split;
     }
@@ -961,6 +973,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
             /* .no_fug       = */ params.no_fug,
             /* .use_thp      = */ params.use_thp,
             /* .no_ooae      = */ params.no_ooae,
+            /* .mqkv         = */ params.mqkv,
             /* .buft_overrides=*/ params.buft_overrides.data(),
         };
         instances.push_back(instance);
@@ -998,6 +1011,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
             /* .no_fug       = */ params.no_fug,
             /* .use_thp      = */ params.use_thp,
             /* .no_ooae      = */ params.no_ooae,
+            /* .mqkv         = */ params.mqkv,
             /* .buft_overrides=*/ params.buft_overrides.data(),
         };
         instances.push_back(instance);
@@ -1035,6 +1049,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
             /* .no_fug       = */ params.no_fug,
             /* .use_thp      = */ params.use_thp,
             /* .no_ooae      = */ params.no_ooae,
+            /* .mqkv         = */ params.mqkv,
             /* .buft_overrides=*/ params.buft_overrides.data(),
         };
         instances.push_back(instance);
@@ -1071,7 +1086,8 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
             /* .ger          = */ params.ger,
             /* .no_fug       = */ params.no_fug,
             /* .use_thp      = */ params.use_thp,
-            /* .no_ooae      = */ params.no_ooae,
+            /* .no_ooae      = */ params.no_ooae,
+            /* .mqkv         = */ params.mqkv,
             /* .buft_overrides=*/ params.buft_overrides.data(),
         };
         instances.push_back(instance);
@@ -1120,6 +1136,7 @@ struct test {
     bool no_fug = false;
     bool use_thp = false;
     bool no_ooae = false;
+    bool mqkv = false;
     int n_prompt;
     int n_gen;
     std::string test_time;
@@ -1152,6 +1169,7 @@ struct test {
         use_mmap = inst.use_mmap;
         embeddings = inst.embeddings;
         repack = inst.repack;
+        mqkv = inst.mqkv;
         fmoe = inst.fmoe;
         ger = inst.ger;
         no_fug = inst.no_fug;
@@ -1247,7 +1265,7 @@ struct test {
             "n_threads", "type_k", "type_v",
             "n_gpu_layers", "split_mode", "main_gpu", "no_kv_offload", "flash_attn",
             "mla_attn", "attn_max_batch", "ser",
-            "tensor_split", "use_mmap", "embeddings", "repack", "fused_moe", "grouped_er", "fused_up_gate", "use_thp", "ooae",
+            "tensor_split", "use_mmap", "embeddings", "repack", "mqkv", "fused_moe", "grouped_er", "fused_up_gate", "use_thp", "ooae",
             "n_prompt", "n_gen", "test_time",
             "avg_ns", "stddev_ns",
             "avg_ts", "stddev_ts",
             "test",
@@ -1269,7 +1287,7 @@ struct test {
         if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" ||
             field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
             field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" || field == "use_thp" ||
-            field == "fused_moe" || field == "grouped_er" || field == "fused_up_gate" || field == "ooae") {
+            field == "fused_moe" || field == "grouped_er" || field == "fused_up_gate" || field == "ooae" || field == "mqkv") {
             return BOOL;
         }
         if (field == "avg_ts" || field == "stddev_ts") {
@@ -1313,7 +1331,7 @@ struct test {
             std::to_string(mla_attn), std::to_string(attn_max_batch), ser_to_string(ser),
             tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings),
-            std::to_string(repack), std::to_string(fmoe), std::to_string(ger),
+            std::to_string(repack), std::to_string(mqkv), std::to_string(fmoe), std::to_string(ger),
             std::to_string(no_fug), std::to_string(use_thp), std::to_string(no_ooae),
             std::to_string(n_prompt), std::to_string(n_gen), test_time,
             std::to_string(avg_ns()), std::to_string(stdev_ns()),
             std::to_string(avg_ts()), std::to_string(stdev_ts()),
@@ -1491,6 +1509,9 @@ struct markdown_printer : public printer {
         if (field == "repack") {
             return 3;
         }
+        if (field == "mqkv") {
+            return 4;
+        }
         if (field == "use_thp") {
             return 3;
         }
@@ -1549,6 +1570,9 @@ struct markdown_printer : public printer {
         if (field == "repack") {
             return "rtr";
         }
+        if (field == "mqkv") {
+            return "mqkv";
+        }
         if (field == "use_thp") {
             return "thp";
         }
@@ -1634,6 +1658,9 @@ struct markdown_printer : public printer {
         if (params.repack != cmd_params_defaults.repack) {
             fields.emplace_back("repack");
         }
+        if (params.mqkv != cmd_params_defaults.mqkv) {
+            fields.emplace_back("mqkv");
+        }
         if (params.use_thp != cmd_params_defaults.use_thp) {
             fields.emplace_back("use_thp");
         }
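The new flag is a plain 0|1 toggle parsed with std::stoi and forwarded to the model loader through cmd_params_instance::to_llama_mparams(). A minimal sketch of the equivalent direct API use follows; the merge_qkv field comes from this patch, while the model file name is a placeholder and the surrounding loader calls are assumed from this fork's llama.h:

    // shell: enable QKV merging for a bench run (0|1 argument)
    //   ./llama-bench -m model.gguf -p 512 -n 128 -mqkv 1

    // C++: what -mqkv 1 amounts to when driving the API directly (sketch)
    #include "llama.h"

    int main() {
        llama_backend_init();
        llama_model_params mparams = llama_model_default_params();
        mparams.merge_qkv = true; // set from params.mqkv by to_llama_mparams()
        llama_model * model = llama_load_model_from_file("model.gguf", mparams);
        if (!model) { return 1; }
        llama_free_model(model);
        llama_backend_free();
        return 0;
    }

Note that unlike most llama-bench parameters, mqkv is a single bool rather than a vector, so it does not accept a comma-separated list of values to sweep in one run.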
diff --git a/ggml/src/ggml-cuda/mmvq-templates.cuh b/ggml/src/ggml-cuda/mmvq-templates.cuh
index 6b4a055f4..f9df3c9d8 100644
--- a/ggml/src/ggml-cuda/mmvq-templates.cuh
+++ b/ggml/src/ggml-cuda/mmvq-templates.cuh
@@ -112,6 +112,10 @@ static __device__ void mul_mat_vec_q(
         }
     }

+    float local_bias[rows_per_cuda_block] = { 0.0f };
+    if (bias && threadIdx.y == 0 && threadIdx.x < rows_per_cuda_block && row0 + threadIdx.x < nrows_dst) {
+        local_bias[threadIdx.x] = bias[row0 + threadIdx.x];
+    }
     __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
     if (threadIdx.y > 0) {
 #pragma unroll
@@ -140,7 +144,7 @@ static __device__ void mul_mat_vec_q(
         }

         if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
-            dst[j*nrows_dst + row0 + threadIdx.x] = bias ? tmp[j][threadIdx.x] + bias[j*nrows_dst + row0 + threadIdx.x] : tmp[j][threadIdx.x];
+            dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x] + local_bias[threadIdx.x];
         }
     }
 }
@@ -176,6 +180,14 @@ static __device__ void fused_mul_mat_vec_q(
     // partial sum for each thread
     float tmp_u[ncols_y][rows_per_cuda_block] = {0.0f};
     float tmp_g[ncols_y][rows_per_cuda_block] = {0.0f};
+    float local_bias_u[rows_per_cuda_block] = { 0.0f };
+    float local_bias_g[rows_per_cuda_block] = { 0.0f };
+    if (bias_u && threadIdx.y == 0 && threadIdx.x < rows_per_cuda_block && row0 + threadIdx.x < nrows_dst) {
+        local_bias_u[threadIdx.x] = bias_u[row0 + threadIdx.x];
+    }
+    if (bias_g && threadIdx.y == 0 && threadIdx.x < rows_per_cuda_block && row0 + threadIdx.x < nrows_dst) {
+        local_bias_g[threadIdx.x] = bias_g[row0 + threadIdx.x];
+    }

     const block_q8_1 * y = (const block_q8_1 *) vy;

@@ -242,8 +254,8 @@ static __device__ void fused_mul_mat_vec_q(
             default: {
                 constexpr float alpha = 1.702f;
                 constexpr float limit = 7.0f;
-                g += bias_g[j*nrows_dst + row0 + threadIdx.x];
-                u += bias_u[j*nrows_dst + row0 + threadIdx.x];
+                g += local_bias_g[threadIdx.x];
+                u += local_bias_u[threadIdx.x];
                 g = fminf(g, limit);
                 u = fmaxf(fminf(u, limit), -limit);
                 r = g / (1.0f + expf(-g * alpha)) * (1.0f + u);
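The kernel change replaces per-store global reads of bias[j*nrows_dst + row0 + threadIdx.x] with a single register load of bias[row0 + threadIdx.x], done once by the threads of warp 0 that later write the result, and it makes the bias index independent of the column j, matching a 1-D per-row bias vector. Below is a standalone sketch of the pattern with hypothetical names; it is a simplified illustration, not the real mul_mat_vec_q template, and it uses a scalar where the diff uses a small per-thread array indexed by threadIdx.x:

    // Register-cached per-row bias (assumed simplification of the diff's
    // pattern). Thread (x, 0) is the one that writes rows row0+x, so it
    // alone loads bias[row0+x] once and reuses the register for every
    // column instead of re-reading global memory per output.
    template <int rows_per_block>
    __global__ void add_bias_rows(float * dst, const float * bias,
                                  const int nrows_dst, const int ncols) {
        const int row0 = blockIdx.x * rows_per_block;

        float local_bias = 0.0f;
        if (bias && threadIdx.y == 0 && threadIdx.x < rows_per_block &&
            row0 + threadIdx.x < nrows_dst) {
            local_bias = bias[row0 + threadIdx.x]; // one load, per-row index only
        }

        if (threadIdx.y == 0 && threadIdx.x < rows_per_block &&
            row0 + threadIdx.x < nrows_dst) {
            for (int j = 0; j < ncols; ++j) {
                // the cached value is reused for all ncols outputs of this row
                dst[j*nrows_dst + row0 + threadIdx.x] += local_bias;
            }
        }
    }

Keeping the value in a register also makes the null-bias case uniform: when bias is null the register stays 0.0f, so the store becomes an unconditional tmp + local_bias, which is what lets the diff drop the ternary in the dst write.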