35 changes: 35 additions & 0 deletions common/common.cpp
@@ -322,6 +322,26 @@ bool parse_buft_overrides(const std::string& value, std::vector<llama_model_tens
}
return true;
}
template<class T1, class T2>
std::vector<std::pair<T1,T2>> string_split_pairs(const std::string & str, char delim) {
std::vector<std::pair<T1,T2>> values;
std::istringstream str_stream(str);
std::string token;
T1 first_value;
int i = 0;
while (std::getline(str_stream, token, delim)) {
std::istringstream token_stream(token);
if (i%2 == 0) {
token_stream >> first_value;
} else {
T2 value;
token_stream >> value;
values.emplace_back(first_value, value);
}
i++;
}
return values;
}
}

#define CHECK_ARG if (++i >= argc) { invalid_param = true; return true; }
@@ -864,6 +884,17 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.fused_moe_up_gate = true;
return true;
}
if (arg == "-ser" || arg == "--smart-expert-reduction") {
CHECK_ARG
auto values = string_split_pairs<int,float>(argv[i], ',');
if (values.size() == 1) {
params.min_experts = values.front().first;
params.thresh_experts = values.front().second;
} else {
invalid_param = true;
}
return true;
}
if (arg == "-co" || arg == "--color") {
params.use_color = true;
return true;
@@ -1523,6 +1554,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %d)", params.mla_attn });
options.push_back({ "*", "-amb, --attention-max-batch", "max batch size for attention computations (default: %d)", params.attn_max_batch});
options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" });
options.push_back({ "*", "-ser, --smart-expert-reduction,","experts reduction (default: %d,%g)", params.min_experts, params.thresh_experts});
options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"
"in conversation mode, this will be used as system prompt\n"
"(default: '%s')", params.prompt.c_str() });
@@ -2368,6 +2400,8 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.mla_attn = params.mla_attn;
cparams.attn_max_batch = params.attn_max_batch;
cparams.fused_moe_up_gate = params.fused_moe_up_gate;
cparams.min_experts = params.min_experts;
cparams.thresh_experts = params.thresh_experts;

cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
@@ -3368,6 +3402,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "mla_attn: %d # default: 0\n", params.mla_attn);
fprintf(stream, "attn_max_batch: %d # default: 0\n", params.attn_max_batch);
fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false");
fprintf(stream, "ser: %d,%g # defaulr: -1,0\n", params.min_experts, params.thresh_experts);
fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);

const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());
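To make the new option concrete, the sketch below (not part of this diff) copies the `string_split_pairs` helper verbatim and feeds it the kind of argument the `-ser` handler expects, e.g. `-ser 7,0.5`. Reading the two values as a minimum expert count and a score threshold is inferred from the option name and the `min_experts`/`thresh_experts` field names; the diff itself does not spell out their semantics.

```cpp
// Illustrative only: how "-ser 7,0.5" is parsed by the helper added above.
#include <cstdio>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

template<class T1, class T2>
std::vector<std::pair<T1,T2>> string_split_pairs(const std::string & str, char delim) {
    std::vector<std::pair<T1,T2>> values;
    std::istringstream str_stream(str);
    std::string token;
    T1 first_value;
    int i = 0;
    while (std::getline(str_stream, token, delim)) {
        std::istringstream token_stream(token);
        if (i % 2 == 0) {
            token_stream >> first_value;   // even tokens -> T1 (here: the expert count)
        } else {
            T2 value;
            token_stream >> value;         // odd tokens -> T2 (here: the threshold)
            values.emplace_back(first_value, value);
        }
        i++;
    }
    return values;
}

int main() {
    // Mirrors the "-ser" branch of gpt_params_find_arg: exactly one pair is accepted.
    auto values = string_split_pairs<int, float>("7,0.5", ',');
    if (values.size() == 1) {
        std::printf("min_experts = %d, thresh_experts = %g\n",
                    values.front().first, values.front().second);   // prints: 7, 0.5
    } else {
        std::printf("invalid -ser argument\n");
    }
    return 0;
}
```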
2 changes: 2 additions & 0 deletions common/common.h
@@ -178,6 +178,8 @@ struct gpt_params {
int mla_attn = 0; // MLA 0: standard attention, 1: MLA with K and transposed V cache, 2: MLA with just K cache
int attn_max_batch = 0; // Max batch size to use when computing attention (only applicable if flash_attn = false)
bool fused_moe_up_gate = false; // fused up*unary(gate) op for MoE models
int min_experts = -1;
float thresh_experts = 0;

bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool ignore_eos = false; // ignore generated EOS tokens
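The two new fields default to `-1` and `0`, which presumably leaves expert reduction disabled. A minimal sketch of setting the corresponding context parameters directly, assuming `llama_context_params` in this fork exposes the same `min_experts`/`thresh_experts` members that `llama_context_params_from_gpt_params` copies into; the meaning attached to the two values is an assumption based on the option name:

```cpp
// Hypothetical sketch: programmatic equivalent of `-ser 7,0.2`.
#include "llama.h"

llama_context_params make_ser_context_params() {
    llama_context_params cparams = llama_context_default_params();
    cparams.min_experts    = 7;     // assumed: lower bound on experts kept per token
    cparams.thresh_experts = 0.2f;  // assumed: score threshold below which experts are dropped
    // The defaults (-1, 0) leave smart expert reduction disabled.
    return cparams;
}
```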
65 changes: 63 additions & 2 deletions examples/llama-bench/llama-bench.cpp
@@ -215,6 +215,9 @@ static std::string pair_str(const std::pair<int, int> & p) {
return buf;
}

// Ser = Smart Expert Reduction
using Ser = std::pair<int,float>;

struct cmd_params {
std::vector<std::string> model;
std::vector<int> n_prompt;
@@ -234,6 +237,7 @@ struct cmd_params {
std::vector<bool> flash_attn;
std::vector<int> mla_attn;
std::vector<int> attn_max_batch;
std::vector<Ser> ser;
std::vector<std::vector<float>> tensor_split;
std::vector<bool> use_mmap;
std::vector<bool> embeddings;
@@ -267,6 +271,7 @@ static const cmd_params cmd_params_defaults = {
/* flash_attn */ {false},
/* mla_attn */ {0},
/* attn_max_batch */ {0},
/* ser */ {{-1,0.0f}},
/* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)},
/* use_mmap */ {true},
/* embeddings */ {false},
@@ -304,6 +309,7 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
printf(" -mla, --mla-attn <0|1|2> (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str());
printf(" -amb, --attn-max-batch <i> (default: %s)\n", join(cmd_params_defaults.attn_max_batch, ",").c_str());
printf(" -ser, --smart-expert-reduction <i,f>(default: %s)\n", join(cmd_params_defaults.attn_max_batch, ",").c_str());
printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
printf(" --numa <distribute|isolate|numactl> (default: disabled)\n");
printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
@@ -387,6 +393,28 @@ bool parse_buft_overrides(const std::string& value, std::vector<llama_model_tens
}
return true;
}
template<class T1, class T2>
std::vector<std::pair<T1,T2>> string_split_pairs(const std::string & str, char delim) {
std::vector<std::pair<T1,T2>> values;
std::istringstream str_stream(str);
std::string token;
T1 first_value;
int i = 0;
while (std::getline(str_stream, token, delim)) {
std::istringstream token_stream(token);
if (i%2 == 0) {
token_stream >> first_value;
if (token_stream.fail()) return {};
} else {
T2 value;
token_stream >> value;
if (token_stream.fail()) return {};
values.emplace_back(first_value, value);
}
i++;
}
return values;
}
}

static cmd_params parse_cmd_params(int argc, char ** argv) {
@@ -588,6 +616,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
}
auto p = string_split<int>(argv[i], split_delim);
params.attn_max_batch.insert(params.attn_max_batch.end(), p.begin(), p.end());
} else if (arg == "-ser" || arg == "--smart-expert-reduction") {
if (++i >= argc) {
invalid_param = true;
break;
}
auto p = string_split_pairs<int,float>(argv[i], split_delim);
params.ser.insert(params.ser.end(), p.begin(), p.end());
} else if (arg == "-mmp" || arg == "--mmap") {
if (++i >= argc) {
invalid_param = true;
@@ -701,6 +736,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; }
if (params.mla_attn.empty()) { params.mla_attn = cmd_params_defaults.mla_attn; }
if (params.attn_max_batch.empty()){ params.attn_max_batch = cmd_params_defaults.attn_max_batch; }
if (params.ser.empty()) { params.ser = cmd_params_defaults.ser; }
if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; }
if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; }
if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; }
@@ -739,6 +775,7 @@ struct cmd_params_instance {
bool flash_attn;
int mla_attn;
int attn_max_batch;
Ser ser;
std::vector<float> tensor_split;
bool use_mmap;
bool embeddings;
@@ -787,6 +824,8 @@ struct cmd_params_instance {
cparams.mla_attn = mla_attn;
cparams.attn_max_batch = attn_max_batch;
cparams.fused_moe_up_gate = fmoe;
cparams.min_experts = ser.first;
cparams.thresh_experts = ser.second;
cparams.embeddings = embeddings;

return cparams;
@@ -813,6 +852,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
for (const auto & fa : params.flash_attn)
for (const auto & mla : params.mla_attn)
for (const auto & amb : params.attn_max_batch)
for (const auto & ser : params.ser)
for (const auto & nt : params.n_threads) {
for (const auto & n_prompt : params.n_prompt) {
if (n_prompt == 0) {
@@ -836,6 +876,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .flash_attn = */ fa,
/* .mla_attn = */ mla,
/* .attn_max_b = */ amb,
/* .ser = */ ser,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -868,6 +909,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .flash_attn = */ fa,
/* .mla_attn = */ mla,
/* .attn_max_b = */ amb,
/* .ser = */ ser,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -900,6 +942,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .flash_attn = */ fa,
/* .mla_attn = */ mla,
/* .attn_max_b = */ amb,
/* .ser = */ ser,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -932,6 +975,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .flash_attn = */ fa,
/* .mla_attn = */ mla,
/* .attn_max_b = */ amb,
/* .ser = */ ser,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -975,6 +1019,7 @@ struct test {
bool flash_attn;
int mla_attn;
int attn_max_batch;
Ser ser;
std::vector<float> tensor_split;
bool use_mmap;
bool embeddings;
@@ -1007,6 +1052,7 @@ struct test {
flash_attn = inst.flash_attn;
mla_attn = inst.mla_attn;
attn_max_batch = inst.attn_max_batch;
ser = inst.ser;
tensor_split = inst.tensor_split;
use_mmap = inst.use_mmap;
embeddings = inst.embeddings;
@@ -1101,7 +1147,7 @@ struct test {
"n_batch", "n_ubatch",
"n_threads", "type_k", "type_v",
"n_gpu_layers", "split_mode",
"main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch",
"main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", "ser",
"tensor_split", "use_mmap", "embeddings", "repack", "fused_moe",
"n_prompt", "n_gen", "test_time",
"avg_ns", "stddev_ns",
@@ -1149,6 +1195,11 @@ struct test {
tensor_split_str += "/";
}
}
auto ser_to_string = [] (const Ser& ser) {
std::ostringstream str;
str << ser.first << ',' << ser.second;
return str.str();
};
std::vector<std::string> values = {
build_commit, std::to_string(build_number),
std::to_string(cuda), std::to_string(vulkan), std::to_string(vulkan),
@@ -1158,7 +1209,8 @@ struct test {
std::to_string(n_batch), std::to_string(n_ubatch),
std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
std::to_string(n_gpu_layers), split_mode_str(split_mode),
std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn), std::to_string(attn_max_batch),
std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn),
std::to_string(mla_attn), std::to_string(attn_max_batch), ser_to_string(ser),
tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack), std::to_string(fmoe),
std::to_string(n_prompt), std::to_string(n_gen), test_time,
std::to_string(avg_ns()), std::to_string(stdev_ns()),
@@ -1328,6 +1380,9 @@ struct markdown_printer : public printer {
if (field == "attn_max_batch") {
return 5;
}
if (field == "ser") {
return 10;
}
if (field == "use_mmap") {
return 4;
}
@@ -1371,6 +1426,9 @@ struct markdown_printer : public printer {
if (field == "attn_max_batch") {
return "amb";
}
if (field == "attn_max_batch") {
return "ser";
}
if (field == "use_mmap") {
return "mmap";
}
@@ -1432,6 +1490,9 @@ struct markdown_printer : public printer {
if (params.attn_max_batch.size() > 1 || params.attn_max_batch != cmd_params_defaults.attn_max_batch) {
fields.emplace_back("attn_max_batch");
}
if (params.ser.size() > 1 || params.ser != cmd_params_defaults.ser) {
fields.emplace_back("ser");
}
if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) {
fields.emplace_back("tensor_split");
}
13 changes: 13 additions & 0 deletions ggml/include/ggml.h
@@ -597,6 +597,7 @@ extern "C" {
GGML_OP_ARANGE,
GGML_OP_TIMESTEP_EMBEDDING,
GGML_OP_ARGSORT,
GGML_OP_ARGSORT_THRESH,
GGML_OP_LEAKY_RELU,
GGML_OP_SOFTCAP,
GGML_OP_SOFT_CAP_MAX,
@@ -1913,6 +1914,12 @@ extern "C" {
struct ggml_tensor * a,
enum ggml_sort_order order);

GGML_API struct ggml_tensor * ggml_argsort_thresh(
struct ggml_context * ctx,
struct ggml_tensor * a,
int min_entries,
float threshold);

GGML_API struct ggml_tensor * ggml_arange(
struct ggml_context * ctx,
float start,
@@ -1924,6 +1931,12 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a,
int k);
GGML_API struct ggml_tensor * ggml_top_k_thresh(
struct ggml_context * ctx,
struct ggml_tensor * a,
int k,
int min_entries,
float thresh);

#define GGML_KQ_MASK_PAD 32

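The implementations of `ggml_argsort_thresh` and `ggml_top_k_thresh` are not part of this header diff. As a rough sketch only, the thresholded top-k could be composed from the new argsort the same way `ggml_top_k` wraps `ggml_argsort` upstream (sort descending, then view the first `k` indices per row); the behavior ascribed to `ggml_argsort_thresh` below is an assumption based on its name and parameters:

```cpp
// Hypothetical sketch, not the actual definition from this PR.
#include "ggml.h"

struct ggml_tensor * ggml_top_k_thresh_sketch(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        int                   k,
        int                   min_entries,
        float                 thresh) {
    GGML_ASSERT(a->ne[0] >= k);

    // Assumed behavior: a descending argsort that may invalidate entries whose score
    // falls below `thresh`, while always keeping at least `min_entries` of them.
    struct ggml_tensor * result = ggml_argsort_thresh(ctx, a, min_entries, thresh);

    // Keep the first k sorted indices per row, exactly as ggml_top_k does with ggml_argsort.
    result = ggml_view_4d(ctx, result,
            k, result->ne[1], result->ne[2], result->ne[3],
            result->nb[1], result->nb[2], result->nb[3], 0);

    return result;
}
```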
16 changes: 12 additions & 4 deletions ggml/src/ggml-cuda.cu
@@ -2133,7 +2133,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
for (int64_t id = 0; id < n_ids; id++) {
const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);

GGML_ASSERT(i02 >= 0 && i02 < n_as);
if (i02 < 0 || i02 >= n_as) continue;
//GGML_ASSERT(i02 >= 0 && i02 < n_as);

const int64_t i11 = id % ne11;
const int64_t i12 = iid1;
@@ -2162,7 +2163,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
for (int64_t id = 0; id < n_ids; id++) {
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);

GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
if (row_id_i < 0 || row_id_i >= n_as) continue;
//GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);

if (row_id_i != i02) {
continue;
@@ -2301,7 +2303,8 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
for (int64_t id = 0; id < n_ids; id++) {
const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);

GGML_ASSERT(i02 >= 0 && i02 < n_as);
if (i02 < 0 || i02 >= n_as) continue;
//GGML_ASSERT(i02 >= 0 && i02 < n_as);

const int64_t i11 = id % ne11;
const int64_t i12 = iid1;
@@ -2362,7 +2365,8 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
for (int64_t id = 0; id < n_ids; id++) {
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);

GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
if (row_id_i < 0 || row_id_i >= n_as) continue;
//GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);

if (row_id_i != i02) {
continue;
@@ -2637,6 +2641,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_ARGSORT:
ggml_cuda_op_argsort(ctx, dst);
break;
case GGML_OP_ARGSORT_THRESH:
ggml_cuda_op_argsort_thresh(ctx, dst);
break;
case GGML_OP_FLASH_ATTN_EXT:
ggml_cuda_flash_attn_ext(ctx, dst);
break;
@@ -3252,6 +3259,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_POOL_2D:
case GGML_OP_SUM_ROWS:
case GGML_OP_ARGSORT:
case GGML_OP_ARGSORT_THRESH:
case GGML_OP_ACC:
case GGML_OP_GROUP_NORM:
case GGML_OP_UPSCALE:
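Replacing the hard `GGML_ASSERT`s with `continue` suggests that, once expert reduction is active, the expert-id tensor may contain out-of-range entries for experts that were dropped, and consumers are expected to skip those rows rather than abort. A minimal sketch of that consumer-side convention (the out-of-range sentinel is an assumption; the diff only shows that such ids are now tolerated):

```cpp
// Hypothetical sketch: skip expert ids outside [0, n_experts) instead of asserting,
// mirroring the relaxed checks in ggml_cuda_mul_mat_id / ggml_cuda_up_gate_unary.
#include <cstdint>
#include <vector>

void count_rows_per_expert(const std::vector<int32_t> & ids, int n_experts,
                           std::vector<int> & rows_per_expert) {
    rows_per_expert.assign(n_experts, 0);
    for (int32_t id : ids) {
        if (id < 0 || id >= n_experts) {
            continue;   // expert dropped by smart expert reduction: ignore this slot
        }
        rows_per_expert[id]++;
    }
}
```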