common/common.cpp (7 additions, 0 deletions)
@@ -1012,6 +1012,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.fused_moe_up_gate = true;
return true;
}
if (arg == "-ger" || arg == "--grouped-expert-routing") {
params.grouped_expert_routing = true;
return true;
}
if (arg == "-no-fug" || arg == "--no-fused-up-gate") {
params.fused_up_gate = false;
return true;
@@ -1800,6 +1804,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %d)", params.mla_attn });
options.push_back({ "*", "-amb, --attention-max-batch", "max batch size for attention computations (default: %d)", params.attn_max_batch});
options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" });
options.push_back({ "*", "-ger, --grouped-expert-routing", "enable grouped expert routing (default: %s)", params.grouped_expert_routing ? "enabled" : "disabled" });
options.push_back({ "*", "-no-fug, --no-fused-up-gate", "disaable fused up-gate (default: %s)", params.fused_up_gate ? "enabled" : "disabled" });
options.push_back({ "*", "-ser, --smart-expert-reduction,","experts reduction (default: %d,%g)", params.min_experts, params.thresh_experts});
options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"
@@ -2755,6 +2760,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.mla_attn = params.mla_attn;
cparams.attn_max_batch = params.attn_max_batch;
cparams.fused_moe_up_gate = params.fused_moe_up_gate;
cparams.grouped_expert_routing = params.grouped_expert_routing;
cparams.fused_up_gate = params.fused_up_gate;
cparams.min_experts = params.min_experts;
cparams.thresh_experts = params.thresh_experts;
@@ -3871,6 +3877,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "mla_attn: %d # default: 0\n", params.mla_attn);
fprintf(stream, "attn_max_batch: %d # default: 0\n", params.attn_max_batch);
fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false");
fprintf(stream, "grouped_expert_routing: %s # default: false\n", params.grouped_expert_routing ? "true" : "false");
fprintf(stream, "fused_up_gate: %s # default: true\n", params.fused_up_gate ? "true" : "false");
fprintf(stream, "ser: %d,%g # defaulr: -1,0\n", params.min_experts, params.thresh_experts);
fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
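In plain usage the new switch behaves like -fmoe: it takes no argument and simply flips the boolean. A hypothetical invocation (binary name, model path, and prompt are placeholders; the option only matters for models with grouped routing, such as the BailingMoeV2 architecture named in common.h below):

    ./llama-cli -m ./model.gguf -fmoe -ger -p "Hello"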
common/common.h (1 addition, 0 deletions)
@@ -235,6 +235,7 @@ struct gpt_params {
int attn_max_batch = 0; // Max batch size to use when computing attention (only applicable if flash_attn = false)
bool fused_moe_up_gate = false; // fused up*unary(gate) op for MoE models
bool fused_up_gate = true; // fused up*unary(gate) op
bool grouped_expert_routing = false; // whether to use grouped expert routing (BailingMoeV2 arch)
int min_experts = -1;
float thresh_experts = 0;

examples/llama-bench/llama-bench.cpp (35 additions, 4 deletions)
@@ -261,6 +261,7 @@ struct cmd_params {
bool warmup;
bool repack = false;
bool fmoe = false;
bool ger = false; // ger = Grouped Expert Routing
bool no_fug = false;
bool use_thp = false;
output_formats output_format;
@@ -296,9 +297,10 @@ static const cmd_params cmd_params_defaults = {
/* verbose */ false,
/* warmup */ true,
/* repack */ false,
/* use_thp */ false,
/* fmoe */ false,
/* ger */ false,
/* no_fug */ false,
/* use_thp */ false,
/* output_format */ MARKDOWN,
/* output_format_stderr */ NONE,
};
@@ -341,6 +343,7 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -thp, --transparent-huge-pages <0|1> (default: %s)\n", cmd_params_defaults.use_thp? "1" : "0");
printf(" -ot, --override-tensor pattern (default: none)\n");
printf(" -fmoe, --fused-moe <0|1> (default: %s)\n", cmd_params_defaults.fmoe? "1" : "0");
printf(" -ger, --grouped-expert-routing <0|1>(default: %s)\n", cmd_params_defaults.ger ? "1" : "0");
printf(" -no-fug, --no-fused-up-gate <0|1> (default: %s)\n", cmd_params_defaults.no_fug? "1" : "0");
printf("\n");
printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n");
@@ -739,6 +742,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
break;
}
params.fmoe = std::stoi(argv[i]);
} else if (arg == "-ger" || arg == "--grouped-expert-routing") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.ger = std::stoi(argv[i]);
} else if (arg == "-no-fug" || arg == "--no-fused-up-gate") {
if (++i >= argc) {
invalid_param = true;
@@ -829,6 +838,7 @@ struct cmd_params_instance {
bool embeddings;
bool repack = false;
bool fmoe = false;
bool ger = false;
bool no_fug = false;
bool use_thp = false;
const llama_model_tensor_buft_override* buft_overrides;
@@ -876,6 +886,7 @@ struct cmd_params_instance {
cparams.mla_attn = mla_attn;
cparams.attn_max_batch = attn_max_batch;
cparams.fused_moe_up_gate = fmoe;
cparams.grouped_expert_routing = ger;
cparams.fused_up_gate = !no_fug;
cparams.min_experts = ser.first;
cparams.thresh_experts = ser.second;
@@ -935,6 +946,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .embeddings = */ embd,
/* .repack = */ params.repack,
/* .fmoe = */ params.fmoe,
/* .ger = */ params.ger,
/* .no_fug = */ params.no_fug,
/* .use_thp = */ params.use_thp,
/* .buft_overrides=*/ params.buft_overrides.data(),
@@ -970,6 +982,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .embeddings = */ embd,
/* .repack = */ params.repack,
/* .fmoe = */ params.fmoe,
/* .ger = */ params.ger,
/* .no_fug = */ params.no_fug,
/* .use_thp = */ params.use_thp,
/* .buft_overrides=*/ params.buft_overrides.data(),
@@ -1005,6 +1018,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .embeddings = */ embd,
/* .repack = */ params.repack,
/* .fmoe = */ params.fmoe,
/* .ger = */ params.ger,
/* .no_fug = */ params.no_fug,
/* .use_thp = */ params.use_thp,
/* .buft_overrides=*/ params.buft_overrides.data(),
@@ -1040,6 +1054,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .embeddings = */ embd,
/* .repack = */ params.repack,
/* .fmoe = */ params.fmoe,
/* .ger = */ params.ger,
/* .no_fug = */ params.no_fug,
/* .use_thp = */ params.use_thp,
/* .buft_overrides=*/ params.buft_overrides.data(),
@@ -1086,6 +1101,7 @@ struct test {
bool embeddings;
bool repack = false;
bool fmoe = false;
bool ger = false;
bool no_fug = false;
bool use_thp = false;
int n_prompt;
@@ -1120,6 +1136,8 @@ struct test {
use_mmap = inst.use_mmap;
embeddings = inst.embeddings;
repack = inst.repack;
fmoe = inst.fmoe;
ger = inst.ger;
no_fug = inst.no_fug;
use_thp = inst.use_thp;
n_prompt = inst.n_prompt;
@@ -1212,7 +1230,7 @@ struct test {
"n_threads", "type_k", "type_v",
"n_gpu_layers", "split_mode",
"main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", "ser",
"tensor_split", "use_mmap", "embeddings", "repack", "fused_moe", "fused_up_gate", "use_thp",
"tensor_split", "use_mmap", "embeddings", "repack", "fused_moe", "grouped_er", "fused_up_gate", "use_thp",
"n_prompt", "n_gen", "test_time",
"avg_ns", "stddev_ns",
"avg_ts", "stddev_ts", "test",
@@ -1234,7 +1252,7 @@
if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" ||
field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" || field == "use_thp" ||
field == "fused_moe" || field == "fused_up_gate") {
field == "fused_moe" || field == "grouped_er" || field == "fused_up_gate") {
return BOOL;
}
if (field == "avg_ts" || field == "stddev_ts") {
@@ -1277,7 +1295,8 @@ struct test {
std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn),
std::to_string(mla_attn), std::to_string(attn_max_batch), ser_to_string(ser),
tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings),
std::to_string(repack), std::to_string(fmoe), std::to_string(no_fug), std::to_string(use_thp),
std::to_string(repack), std::to_string(fmoe), std::to_string(ger),
std::to_string(no_fug), std::to_string(use_thp),
std::to_string(n_prompt), std::to_string(n_gen), test_time,
std::to_string(avg_ns()), std::to_string(stdev_ns()),
std::to_string(avg_ts()), std::to_string(stdev_ts()),
@@ -1461,6 +1480,9 @@ struct markdown_printer : public printer {
if (field == "fused_moe") {
return 4;
}
if (field == "grouped_er") {
return 3;
}
if (field == "fused_up_gate") {
return 6;
}
@@ -1513,6 +1535,12 @@ struct markdown_printer : public printer {
if (field == "fused_moe") {
return "fmoe";
}
if (field == "grouped_er") {
return "ger";
}
if (field == "grouped_er") {
return "ger";
}
if (field == "fused_up_gate") {
return "no-fug";
}
@@ -1589,6 +1617,9 @@ struct markdown_printer : public printer {
if (params.fmoe != cmd_params_defaults.fmoe) {
fields.emplace_back("fused_moe");
}
if (params.ger != cmd_params_defaults.ger) {
fields.emplace_back("grouped_er");
}
if (params.no_fug != cmd_params_defaults.no_fug) {
fields.emplace_back("fused_up_gate");
}
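Note that unlike the list-valued llama-bench parameters, -ger is parsed above with std::stoi into a single bool, so it does not accept a comma-separated value list; comparing the toggle therefore takes two runs (model path is a placeholder):

    ./llama-bench -m ./model.gguf -fmoe 1 -ger 0
    ./llama-bench -m ./model.gguf -fmoe 1 -ger 1

When the value differs from the default, the result table gains a "ger" column (internal field name grouped_er), mirroring how fmoe is reported.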
ggml/include/ggml.h (8 additions, 0 deletions)
@@ -650,6 +650,7 @@ extern "C" {
GGML_OP_TIMESTEP_EMBEDDING,
GGML_OP_ARGSORT,
GGML_OP_ARGSORT_THRESH,
GGML_OP_GROUPED_TOPK,
GGML_OP_LEAKY_RELU,
GGML_OP_SOFTCAP,
GGML_OP_SOFT_CAP_MAX,
Expand Down Expand Up @@ -2265,6 +2266,13 @@ extern "C" {
int k,
int min_entries,
float thresh);
GGML_API struct ggml_tensor * ggml_grouped_topk(
struct ggml_context * ctx,
struct ggml_tensor * a,
int num_groups,
int num_top_groups,
int nk,
int topk_experts);

#define GGML_KQ_MASK_PAD 16

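A minimal sketch of how a caller might place the new op in a graph. The shapes and group counts are illustrative, not taken from this PR, and nk is read here as the number of top experts per group used to score each group, which is consistent with the asserts in ggml.c below but is an assumption; ctx is an initialized ggml_context.

    // route 2 tokens over 8 experts arranged as 4 groups of 2:
    // keep the best 2 groups, score each group by its top 2 experts,
    // then pick 4 experts overall from the surviving groups
    struct ggml_tensor * logits  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 2); // [n_expert, n_tokens]
    struct ggml_tensor * experts = ggml_grouped_topk(ctx, logits,
            /*num_groups     =*/ 4,
            /*num_top_groups =*/ 2,
            /*nk             =*/ 2,
            /*topk_experts   =*/ 4);
    // result is GGML_TYPE_I32 with ne[0] == topk_experts, holding the selected expert indices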
ggml/src/ggml.c (66 additions, 2 deletions)
@@ -4253,6 +4253,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"TIMESTEP_EMBEDDING",
"ARGSORT",
"ARGSORT_THRESH",
"GROUPED_TOPK",
"LEAKY_RELU",
"SOFTCAP",
"SOFT_CAP_MAX",
@@ -4288,7 +4289,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GLU",
};

static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -4356,6 +4357,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"timestep_embedding(timesteps, dim, max_period)",
"argsort(x)",
"argsort_thresh(x)",
"grouped_topk(x)",
"leaky_relu(x)",
"k2*tanh(k1*x)",
"soft_max(k2*tanh(k1*x))",
@@ -4391,7 +4393,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"glu(x),"
};

static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -9439,6 +9441,39 @@ struct ggml_tensor * ggml_argsort_thresh(
return result;
}

struct ggml_tensor * ggml_grouped_topk(
struct ggml_context * ctx,
struct ggml_tensor * a,
int num_groups,
int num_top_groups,
int nk,
int topk_experts) {

GGML_ASSERT(num_top_groups <= num_groups);
GGML_ASSERT(a->ne[0] % num_groups == 0);
GGML_ASSERT(a->ne[0] >= topk_experts);
int64_t n_per_group = a->ne[0] / num_groups;
GGML_ASSERT(n_per_group >= nk);

bool is_node = false;

int64_t ne[GGML_MAX_DIMS];
for (int i = 1; i < GGML_MAX_DIMS; ++i) ne[i] = a->ne[i];
ne[0] = topk_experts;
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, ne);

ggml_set_op_params_i32(result, 0, num_groups);
ggml_set_op_params_i32(result, 1, num_top_groups);
ggml_set_op_params_i32(result, 2, nk);

result->op = GGML_OP_GROUPED_TOPK;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;

return result;
}


// ggml_top_k

struct ggml_tensor * ggml_top_k(
@@ -20024,6 +20059,24 @@ static void ggml_compute_forward_argsort_thresh(
}
}

static void ggml_compute_forward_grouped_topk(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {

const struct ggml_tensor * src0 = dst->src[0];

switch (src0->type) {
case GGML_TYPE_F32:
{
iqk_grouped_top_k(dst, params->ith, params->nth);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}

// ggml_compute_forward_flash_attn_ext

static void ggml_compute_forward_flash_attn_ext_f16(
@@ -22521,6 +22574,10 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml
{
ggml_compute_forward_argsort_thresh(params, tensor);
} break;
case GGML_OP_GROUPED_TOPK:
{
ggml_compute_forward_grouped_topk(params, tensor);
} break;
case GGML_OP_LEAKY_RELU:
{
ggml_compute_forward_leaky_relu(params, tensor);
@@ -23539,6 +23596,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
GGML_ABORT("fatal error"); // TODO: not implemented
}
case GGML_OP_GROUPED_TOPK:
{
GGML_ABORT("fatal error"); // TODO: not implemented
}
case GGML_OP_LEAKY_RELU:
{
GGML_ABORT("fatal error"); // TODO: not implemented
@@ -24281,6 +24342,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_ARGSORT:
case GGML_OP_ARGSORT_THRESH:
case GGML_OP_GROUPED_TOPK:
case GGML_OP_FLASH_ATTN_EXT:
case GGML_OP_FLASH_ATTN_BACK:
case GGML_OP_SSM_CONV:
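The forward pass dispatches to iqk_grouped_top_k, whose implementation is not part of this diff. As a rough reference for what grouped routing of this kind computes per row of logits, here is a scalar sketch; the scoring rule (group score = sum of the nk largest logits in the group, final top-k drawn only from the surviving groups) is an assumption consistent with the asserts in ggml_grouped_topk, not a statement of what the kernel actually does.

    // Hypothetical scalar reference for one row; the real kernel is iqk_grouped_top_k.
    // Assumes topk_experts <= num_top_groups * (n_experts / num_groups).
    #include <algorithm>
    #include <functional>
    #include <utility>
    #include <vector>

    static std::vector<int> grouped_topk_row(const float * logits, int n_experts,
                                             int num_groups, int num_top_groups,
                                             int nk, int topk_experts) {
        const int per_group = n_experts / num_groups;

        // 1. score each group by the sum of its nk largest logits
        std::vector<std::pair<float, int>> group_score(num_groups);
        for (int g = 0; g < num_groups; ++g) {
            std::vector<float> v(logits + g*per_group, logits + (g + 1)*per_group);
            std::partial_sort(v.begin(), v.begin() + nk, v.end(), std::greater<float>());
            float s = 0.0f;
            for (int j = 0; j < nk; ++j) s += v[j];
            group_score[g] = {s, g};
        }

        // 2. keep the num_top_groups best-scoring groups
        std::partial_sort(group_score.begin(), group_score.begin() + num_top_groups,
                          group_score.end(), std::greater<>());

        // 3. global top-k restricted to experts of the surviving groups
        std::vector<std::pair<float, int>> cand;
        for (int t = 0; t < num_top_groups; ++t) {
            const int g = group_score[t].second;
            for (int j = 0; j < per_group; ++j) {
                cand.push_back({logits[g*per_group + j], g*per_group + j});
            }
        }
        std::partial_sort(cand.begin(), cand.begin() + topk_experts, cand.end(), std::greater<>());

        std::vector<int> top(topk_experts);
        for (int i = 0; i < topk_experts; ++i) top[i] = cand[i].second;
        return top;
    }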