
Commit 41bde27

ikawrakow and Iwan Kawrakow authored
Graph reuse (#947)
* Add mainline compatible FA command line option
* Graph reuse: add command line argument to turn it on
* WIP
* This seems to work
* This is perhaps cleaner
* Change the command line option to -gr

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent be1a8cb commit 41bde27

File tree

9 files changed: +174 -38 lines changed


common/common.cpp

Lines changed: 9 additions & 1 deletion
@@ -1135,6 +1135,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.flash_attn = false;
         return true;
     }
+
     if (arg == "-fa" || arg == "--flash-attn") {
         CHECK_ARG
         std::string next_arg{argv[i]};
@@ -1180,6 +1181,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.rope_cache = true;
         return true;
     }
+    if (arg == "-gr" || arg == "--graph-reuse") {
+        params.graph_reuse = true;
+        return true;
+    }
     if (arg == "-ser" || arg == "--smart-expert-reduction") {
         CHECK_ARG
         auto values = string_split_pairs<int,float>(argv[i], ',');
@@ -2004,6 +2009,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*", "-no-fug, --no-fused-up-gate", "disaable fused up-gate (default: %s)", params.fused_up_gate ? "enabled" : "disabled" });
     options.push_back({ "*", "-no-mmad, --no-fused-mul-multiadd", "disaable fused mul-multi_add (default: %s)", params.fused_mmad? "enabled" : "disabled" });
     options.push_back({ "*", "-rcache, --rope-cache", "enable RoPE cache (default: %s)", params.rope_cache ? "enabled" : "disabled" });
+    options.push_back({ "*", "-gr, --graph-reuse", "enable graph reuse (default: %s)", params.graph_reuse ? "enabled" : "disabled" });
     options.push_back({ "*", "-ser, --smart-expert-reduction,","experts reduction (default: %d,%g)", params.min_experts, params.thresh_experts});
     options.push_back({ "*", "-mqkv, --merge-qkv,", "merge Q,K,V (default: %d)", params.merge_qkv});
     options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"
@@ -2979,6 +2985,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.fused_up_gate = params.fused_up_gate;
     cparams.fused_mmad = params.fused_mmad;
     cparams.rope_cache = params.rope_cache;
+    cparams.graph_reuse = params.graph_reuse;
     cparams.min_experts = params.min_experts;
     cparams.thresh_experts = params.thresh_experts;
     cparams.only_active_experts = params.only_active_exps;
@@ -4123,7 +4130,8 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "grouped_expert_routing: %s # default: false\n", params.grouped_expert_routing ? "true" : "false");
     fprintf(stream, "fused_up_gate: %s # default: true\n", params.fused_up_gate ? "true" : "false");
     fprintf(stream, "fused_mmad: %s # default: true\n", params.fused_mmad ? "true" : "false");
-    fprintf(stream, "rope_cache: %s # default: true\n", params.rope_cache ? "true" : "false");
+    fprintf(stream, "rope_cache: %s # default: false\n", params.rope_cache ? "true" : "false");
+    fprintf(stream, "graph_reuse: %s # default: false\n", params.graph_reuse ? "true" : "false");
     fprintf(stream, "ser: %d,%g # defaulr: -1,0\n", params.min_experts, params.thresh_experts);
     fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
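Taken together, the common.cpp changes simply thread the new flag through to the context parameters. A minimal sketch of that path, assuming the usual gpt_params_parse() entry point from common.h (the sketch itself is not part of the diff):

```cpp
#include "common.h"   // gpt_params, gpt_params_parse, llama_context_params_from_gpt_params

// Sketch only: -gr / --graph-reuse sets gpt_params::graph_reuse (default false),
// and llama_context_params_from_gpt_params() copies it into llama_context_params.
int main(int argc, char ** argv) {
    gpt_params params;
    if (!gpt_params_parse(argc, argv, params)) {   // -gr handled by gpt_params_find_arg()
        return 1;
    }
    llama_context_params cparams = llama_context_params_from_gpt_params(params);
    // cparams.graph_reuse now mirrors params.graph_reuse
    return 0;
}
```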

common/common.h

Lines changed: 1 addition & 0 deletions
@@ -254,6 +254,7 @@ struct gpt_params {
     bool fused_mmad = true; // fused mul+multi_add op
     bool grouped_expert_routing = false; // if to use grouped expert routing (BailingMoeV2 arch)
     bool rope_cache = false; // if to use RoPE cache (for supported models)
+    bool graph_reuse = false; // if to reuse compute graphs
     int min_experts = -1;
     float thresh_experts = 0;

examples/llama-bench/llama-bench.cpp

Lines changed: 34 additions & 5 deletions
@@ -251,6 +251,7 @@ struct cmd_params {
     std::vector<int> mla_attn;
     std::vector<int> attn_max_batch;
     std::vector<Ser> ser;
+    std::vector<bool> reuse;
     std::vector<std::vector<float>> tensor_split;
     std::vector<bool> use_mmap;
     std::vector<bool> embeddings;
@@ -292,6 +293,7 @@ static const cmd_params cmd_params_defaults = {
     /* mla_attn */ {3},
     /* attn_max_batch */ {0},
     /* ser */ {{-1,0.0f}},
+    /* reuse */ {false},
     /* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)},
     /* use_mmap */ {true},
     /* embeddings */ {false},
@@ -339,6 +341,7 @@ static void print_usage(int /* argc */, char ** argv) {
     printf(" -mla, --mla-attn <0|1|2> (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str());
     printf(" -amb, --attn-max-batch <i> (default: %s)\n", join(cmd_params_defaults.attn_max_batch, ",").c_str());
     printf(" -ser, --smart-expert-reduction <i,f>(default: %s)\n", join(cmd_params_defaults.attn_max_batch, ",").c_str());
+    printf(" -gr, --graph-reuse <0|1> (default: %s)\n", join(cmd_params_defaults.reuse, ",").c_str());
     printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
     printf(" --numa <distribute|isolate|numactl> (default: disabled)\n");
     printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
@@ -681,6 +684,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
            }
            auto p = string_split<int>(argv[i], split_delim);
            params.attn_max_batch.insert(params.attn_max_batch.end(), p.begin(), p.end());
+        } else if (arg == "-gr" || arg == "--graph-reuse") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            auto p = string_split<bool>(argv[i], split_delim);
+            params.reuse.insert(params.reuse.end(), p.begin(), p.end());
        } else if (arg == "-ser" || arg == "--smart-expert-reduction") {
            if (++i >= argc) {
                invalid_param = true;
@@ -852,6 +862,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
     if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; }
     if (params.mla_attn.empty()) { params.mla_attn = cmd_params_defaults.mla_attn; }
     if (params.attn_max_batch.empty()){ params.attn_max_batch = cmd_params_defaults.attn_max_batch; }
+    if (params.reuse.empty()) { params.reuse = cmd_params_defaults.reuse; }
     if (params.ser.empty()) { params.ser = cmd_params_defaults.ser; }
     if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; }
     if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; }
@@ -891,6 +902,7 @@ struct cmd_params_instance {
     bool flash_attn;
     int mla_attn;
     int attn_max_batch;
+    bool reuse;
     Ser ser;
     std::vector<float> tensor_split;
     std::string cuda_params;
@@ -950,6 +962,7 @@ struct cmd_params_instance {
        cparams.flash_attn = flash_attn;
        cparams.mla_attn = mla_attn;
        cparams.attn_max_batch = attn_max_batch;
+        cparams.graph_reuse = reuse;
        cparams.fused_moe_up_gate = fmoe;
        cparams.grouped_expert_routing = ger;
        cparams.rope_cache = rcache;
@@ -984,6 +997,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
     for (const auto & fa : params.flash_attn)
     for (const auto & mla : params.mla_attn)
     for (const auto & amb : params.attn_max_batch)
+    for (const auto & reuse : params.reuse)
     for (const auto & ser : params.ser)
     for (const auto & nt : params.n_threads) {
         for (const auto & n_prompt : params.n_prompt) {
@@ -1008,6 +1022,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
            /* .flash_attn = */ fa,
            /* .mla_attn = */ mla,
            /* .attn_max_b = */ amb,
+            /* .reuse = */ reuse,
            /* .ser = */ ser,
            /* .tensor_split = */ ts,
            /* .cuda_params = */ params.cuda_params,
@@ -1048,6 +1063,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
            /* .flash_attn = */ fa,
            /* .mla_attn = */ mla,
            /* .attn_max_b = */ amb,
+            /* .reuse = */ reuse,
            /* .ser = */ ser,
            /* .tensor_split = */ ts,
            /* .cuda_params = */ params.cuda_params,
@@ -1088,6 +1104,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
            /* .flash_attn = */ fa,
            /* .mla_attn = */ mla,
            /* .attn_max_b = */ amb,
+            /* .reuse = */ reuse,
            /* .ser = */ ser,
            /* .tensor_split = */ ts,
            /* .cuda_params = */ params.cuda_params,
@@ -1128,6 +1145,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
            /* .flash_attn = */ fa,
            /* .mla_attn = */ mla,
            /* .attn_max_b = */ amb,
+            /* .reuse = */ reuse,
            /* .ser = */ ser,
            /* .tensor_split = */ ts,
            /* .cuda_params = */ params.cuda_params,
@@ -1179,6 +1197,7 @@ struct test {
     bool flash_attn;
     int mla_attn;
     int attn_max_batch;
+    bool reuse;
     Ser ser;
     std::vector<float> tensor_split;
     std::string cuda_params;
@@ -1219,6 +1238,7 @@ struct test {
        flash_attn = inst.flash_attn;
        mla_attn = inst.mla_attn;
        attn_max_batch = inst.attn_max_batch;
+        reuse = inst.reuse;
        ser = inst.ser;
        tensor_split = inst.tensor_split;
        cuda_params = inst.cuda_params;
@@ -1321,7 +1341,7 @@ struct test {
            "n_batch", "n_ubatch",
            "n_threads", "type_k", "type_v",
            "n_gpu_layers", "split_mode",
-            "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", "ser",
+            "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", "ser", "reuse",
            "tensor_split", "use_mmap", "embeddings", "repack", "mqkv", "fused_moe", "grouped_er",
            "fused_up_gate", "use_thp", "ooae", "rcache",
            "n_prompt", "n_gen", "test_time",
@@ -1346,7 +1366,7 @@ struct test {
            field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
            field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" || field == "use_thp" ||
            field == "fused_moe" || field == "grouped_er" || field == "fused_up_gate" || field == "ooae" || field == "mqkv" ||
-            field == "rcache") {
+            field == "rcache" || field == "reuse") {
            return BOOL;
        }
        if (field == "avg_ts" || field == "stddev_ts") {
@@ -1387,7 +1407,7 @@ struct test {
            std::to_string(is_gen ? n_threads.first : n_threads.second), ggml_type_name(type_k), ggml_type_name(type_v),
            std::to_string(n_gpu_layers), split_mode_str(split_mode),
            std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn),
-            std::to_string(mla_attn), std::to_string(attn_max_batch), ser_to_string(ser),
+            std::to_string(mla_attn), std::to_string(attn_max_batch), ser_to_string(ser), std::to_string(reuse),
            tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings),
            std::to_string(repack), std::to_string(fmoe), std::to_string(ger), std::to_string(rcache),
            std::to_string(no_fug), std::to_string(use_thp), std::to_string(no_ooae), std::to_string(mqkv),
@@ -1559,6 +1579,9 @@ struct markdown_printer : public printer {
        if (field == "attn_max_batch") {
            return 5;
        }
+        if (field == "reuse") {
+            return 2;
+        }
        if (field == "ser") {
            return 10;
        }
@@ -1623,7 +1646,10 @@ struct markdown_printer : public printer {
        if (field == "attn_max_batch") {
            return "amb";
        }
-        if (field == "attn_max_batch") {
+        if (field == "reuse") {
+            return "gr";
+        }
+        if (field == "ser") {
            return "ser";
        }
        if (field == "use_mmap") {
@@ -1702,9 +1728,12 @@ struct markdown_printer : public printer {
        if (params.mla_attn.size() > 1 || params.mla_attn != cmd_params_defaults.mla_attn) {
            fields.emplace_back("mla_attn");
        }
-        if (params.attn_max_batch.size() > 1 || params.attn_max_batch != cmd_params_defaults.mla_attn) {
+        if (params.attn_max_batch.size() > 1 || params.attn_max_batch != cmd_params_defaults.attn_max_batch) {
            fields.emplace_back("attn_max_batch");
        }
+        if (params.reuse.size() > 1 || params.reuse != cmd_params_defaults.reuse) {
+            fields.emplace_back("reuse");
+        }
        if (params.ser.size() > 1 || params.ser != cmd_params_defaults.ser) {
            fields.emplace_back("ser");
        }
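
With these llama-bench changes, graph reuse becomes another swept dimension: a value list such as "0,1" given to -gr produces one test instance with reuse off and one with it on, and the "gr" column only appears in the markdown table when a non-default value is requested. A rough sketch of the per-instance effect (illustrative only, not code from the benchmark):

```cpp
#include "llama.h"
#include <vector>

// Illustrative only: what sweeping "-gr 0,1" amounts to for each test instance.
// In the patch, string_split<bool>(argv[i], split_delim) turns "0,1" into
// {false, true}; each value is then copied into cparams.graph_reuse.
static std::vector<llama_context_params> make_reuse_instances() {
    std::vector<bool> reuse = {false, true};          // from "-gr 0,1"
    std::vector<llama_context_params> out;
    for (bool r : reuse) {
        llama_context_params cparams = llama_context_default_params();
        cparams.graph_reuse = r;                      // swept like flash_attn, mla_attn, ...
        out.push_back(cparams);
    }
    return out;
}
```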

include/llama.h

Lines changed: 1 addition & 0 deletions
@@ -429,6 +429,7 @@ extern "C" {
         bool fused_up_gate; // whether to use fused up/gate op [EXPERIMENTAL]
         bool fused_mmad; // whether to use fused mul+multi_add op [EXPERIMENTAL]
         bool rope_cache; // whether to use RoPE cache [EXPERIMENTAL]
+        bool graph_reuse; // whether to reuse graphs when possible [EXPERIMENTAL]
         int min_experts;
         float thresh_experts;
         bool only_active_experts;
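
For applications that use the library directly rather than through common.cpp, the new field is just another context parameter. A minimal sketch, assuming the standard llama.h entry points this fork inherits from mainline llama.cpp:

```cpp
#include "llama.h"

int main() {
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file("model.gguf", mparams); // path is illustrative
    if (!model) return 1;

    llama_context_params cparams = llama_context_default_params();
    cparams.graph_reuse = true;   // opt in to compute-graph reuse [EXPERIMENTAL]

    llama_context * ctx = llama_new_context_with_model(model, cparams);
    // ... run inference ...
    llama_free(ctx);
    llama_free_model(model);
    return 0;
}
```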

src/llama-build-context.cpp

Lines changed: 15 additions & 5 deletions
@@ -469,6 +469,7 @@ ggml_tensor * llm_build_context::llm_build_inp_embd(
 }
 
 void llm_build_context::llm_build_kv_store(
+        struct llama_context & lctx,
         struct ggml_context * ctx,
         const llama_hparams & hparams,
         const llama_cparams & cparams,
@@ -494,29 +495,36 @@ void llm_build_context::llm_build_kv_store(
     //        (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head);
     //cb(k_cache_view, "k_cache_view", il);
 
+    GGML_ASSERT(2*il+1 < (int)lctx.cache_copies.size());
     auto k_row_size = ggml_row_size(kv.k_l[il]->type, n_embd_head_k);
     ggml_tensor * k_cache_view = ggml_view_2d(ctx, kv.k_l[il], n_embd_head_k, n_tokens*n_head_kv,
             k_row_size, k_row_size*n_head_kv*kv_head);
 
+    lctx.cache_copies[2*il+0].cpy = ggml_cpy(ctx, k_cur, k_cache_view);
+    lctx.cache_copies[2*il+0].step = k_row_size*n_head_kv;
+
     // note: storing RoPE-ed version of K in the KV cache
-    ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
+    ggml_build_forward_expand(graph, lctx.cache_copies[2*il+0].cpy);
 
     struct ggml_tensor * v_cache_view = nullptr;
 
     if (cparams.flash_attn) {
         v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa,
                 (kv_head)*ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa));
+        lctx.cache_copies[2*il+1].step = ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa);
     } else {
         // note: the V cache is transposed when not using flash attention
         v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa,
                 ( n_ctx)*ggml_element_size(kv.v_l[il]),
                 (kv_head)*ggml_element_size(kv.v_l[il]));
+        lctx.cache_copies[2*il+1].step = ggml_element_size(kv.v_l[il]);
 
         v_cur = ggml_transpose(ctx, v_cur);
     }
     cb(v_cache_view, "v_cache_view", il);
 
-    ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
+    lctx.cache_copies[2*il+1].cpy = ggml_cpy(ctx, v_cur, v_cache_view);
+    ggml_build_forward_expand(graph, lctx.cache_copies[2*il+1].cpy);
 }
 
 ggml_tensor * llm_build_context::llm_build_lora_mm(
@@ -1205,7 +1213,7 @@ ggml_tensor * llm_build_context::llm_build_kv(
     ggml_build_forward_expand(graph, k_cur);
     ggml_build_forward_expand(graph, v_cur);
 
-    llm_build_kv_store(ctx, hparams, cparams, kv, graph, k_cur, v_cur, n_tokens, kv_head, cb, il);
+    llm_build_kv_store(lctx, ctx, hparams, cparams, kv, graph, k_cur, v_cur, n_tokens, kv_head, cb, il);
 
     struct ggml_tensor * cur;
 
@@ -6045,7 +6053,9 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
     auto row_size = ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope);
     ggml_tensor * kv_cache_view = ggml_view_2d(ctx0, kv_self.k_l[il], kv_self.k_l[il]->ne[0], n_tokens,
             row_size, row_size*kv_head);
-    ggml_build_forward_expand(gf, ggml_cpy(ctx0, kvr, kv_cache_view));
+    lctx.cache_copies[2*il+0].cpy = ggml_cpy(ctx0, kvr, kv_cache_view);
+    lctx.cache_copies[2*il+0].step = row_size;
+    ggml_build_forward_expand(gf, lctx.cache_copies[2*il+0].cpy);
     ggml_tensor * kv_cache = ggml_view_2d(ctx0, kv_self.k_l[il],
             kv_lora_rank + n_embd_head_qk_rope, n_kv,
             ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope), 0);
@@ -7082,7 +7092,7 @@ ggml_cgraph * llm_build_context::build_t5_decoder() {
             model.layers[il].wk, nullptr,
             model.layers[il].wv, nullptr, 0, il);
 
-    llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
+    llm_build_kv_store(lctx, ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
 
     struct ggml_tensor * k =
         ggml_view_3d(ctx0, kv_self.k_l[il],
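
What llm_build_kv_store() now records, for each layer il, is the copy node into the K cache at cache_copies[2*il+0] and into the V cache at cache_copies[2*il+1], plus a step equal to the per-kv_head byte stride of the destination view: in every case above the view is created at an offset of kv_head * step. A hypothetical illustration of how such a record could be retargeted to a different kv_head when a graph is reused, assuming ggml_cpy's destination is reachable as src[1] (the actual reuse logic lives outside this excerpt):

```cpp
#include "ggml.h"

// Hypothetical sketch, not code from the commit: per-layer record of a K or V
// cache copy, mirroring llama_context::CacheCopy from src/llama-context.h.
struct cache_copy_sketch {
    ggml_tensor * cpy  = nullptr;  // node returned by ggml_cpy(ctx, k_cur, k_cache_view)
    size_t        step = 0;        // per-kv_head byte stride of the destination view
};

// The destination views built in llm_build_kv_store() sit at byte offset
// kv_head * step inside the cache tensor, so a reused graph would only need the
// view's data pointer shifted to the new head position (illustrative only).
static void retarget_cache_copy(cache_copy_sketch & cc, ggml_tensor * cache, size_t kv_head) {
    ggml_tensor * dst_view = cc.cpy->src[1];                    // view of kv.k_l[il] / kv.v_l[il]
    dst_view->data = (char *) cache->data + kv_head * cc.step;  // offset = kv_head * step
}
```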

src/llama-build-context.h

Lines changed: 1 addition & 1 deletion
@@ -292,7 +292,7 @@ struct llm_build_context {
             llm_norm_type type,
             const llm_build_cb & cb, int il, float scale_eps = 1);
 
-    static void llm_build_kv_store(ggml_context * ctx, const llama_hparams & hparams,
+    static void llm_build_kv_store(llama_context & lctx, ggml_context * ctx, const llama_hparams & hparams,
             const llama_cparams & cparams,
             const llama_kv_cache & kv,
             ggml_cgraph * graph,

src/llama-context.h

Lines changed: 15 additions & 0 deletions
@@ -9,6 +9,7 @@ struct llama_model;
 #include <vector>
 #include <map>
 #include <set>
+#include <memory>
 
 struct llama_kv_cell {
     llama_pos pos = -1;
@@ -205,4 +206,18 @@ struct llama_context {
 
     ggml_backend_t ggml_backend_by_name(const char * name);
 
+    struct Prev;
+    std::unique_ptr<Prev> prev;
+
+    void reset_scheduler();
+    bool can_reuse_graph(const llama_batch & u_batch);
+
+    struct CacheCopy {
+        ggml_tensor * cpy = nullptr;
+        size_t step = 0;
+    };
+    std::vector<CacheCopy> cache_copies;
+
+    bool update_cache_copies();
+
 };
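
These llama_context members are only declarations; the graph-reuse decision itself is taken in src/llama.cpp, which is part of the same commit but not shown in this excerpt. A hedged sketch of how the hooks could be consumed in the decode path, where the control flow is an assumption and only the member names come from this header:

```cpp
// Hedged sketch (assumed control flow, not the commit's implementation):
// decide whether the previously built compute graph can be evaluated again
// for the incoming micro-batch instead of rebuilding it.
static bool try_reuse_graph(llama_context & lctx, const llama_batch & u_batch) {
    if (!lctx.can_reuse_graph(u_batch)) {   // e.g. batch shape matches the previous graph
        return false;
    }
    // Shift the recorded K/V cache copies to the current cache position before re-running.
    return lctx.update_cache_copies();
}
```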

src/llama-cparams.h

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@ struct llama_cparams {
     bool fused_up_gate;
     bool fused_mmad;
     bool rope_cache;
+    bool graph_reuse;
     int min_experts;
     float thresh_experts;
