
Commit 2e74787

Add --disable-op-offload

Parent: 8c83449

File tree: 11 files changed (+65, -17 lines)

common/arg.cpp

Lines changed: 7 additions & 0 deletions
@@ -2436,6 +2436,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             }
         }
     ));
+    add_opt(common_arg(
+        {"--disable-op-offload"},
+        string_format("disable offloading host tensor operations to device (default: %s)", params.disable_op_offload ? "true" : "false"),
+        [](common_params & params) {
+            params.disable_op_offload = true;
+        }
+    ));
     add_opt(common_arg(
         {"--lora"}, "FNAME",
         "path to LoRA adapter (can be repeated to use multiple adapters)",

common/common.cpp

Lines changed: 1 addition & 0 deletions
@@ -1113,6 +1113,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.offload_kqv = !params.no_kv_offload;
     cparams.flash_attn = params.flash_attn;
     cparams.no_perf = params.no_perf;
+    cparams.disable_op_offload = params.disable_op_offload;
 
     if (params.reranking) {
         cparams.embeddings = true;
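With these common/ changes in place, every tool built on common_params picks up the new switch and forwards it into llama_context_params. A hypothetical invocation (model path and layer count are placeholders, not part of this commit) might look like:

    llama-cli -m model.gguf -ngl 99 --disable-op-offload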

common/common.h

Lines changed: 1 addition & 0 deletions
@@ -332,6 +332,7 @@ struct common_params {
     bool no_kv_offload = false; // disable KV offloading
     bool warmup = true; // warmup run
     bool check_tensors = false; // validate tensor data
+    bool disable_op_offload = false; // globally disable offload host tensor operations to device
 
     bool single_turn = false; // single turn chat conversation

ggml/include/ggml-backend.h

Lines changed: 2 additions & 2 deletions
@@ -248,7 +248,7 @@ extern "C" {
         // preferrably to run on the same backend as the buffer
         ggml_backend_buffer_set_usage(buf_weights, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
 
-        sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false);
+        sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false, false);
 
         // initialize buffers from a max size graph (optional)
         reserve_graph = build_graph(sched, max_batch_size);
@@ -289,7 +289,7 @@ extern "C" {
     typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);
 
     // Initialize a backend scheduler, backends with low index are given priority over backends with high index
-    GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel);
+    GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool disable_op_offload);
     GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
 
     // Initialize backend buffers from a measure graph
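For code that calls the ggml scheduler API directly, the only visible change is the extra trailing bool. A minimal sketch of creating a scheduler with op offload disabled, assuming backend_gpu and backend_cpu were initialized elsewhere (those names are illustrative, not part of this commit):

    ggml_backend_t backends[] = { backend_gpu, backend_cpu };
    ggml_backend_sched_t sched = ggml_backend_sched_new(
        backends, NULL, /* n_backends */ 2, GGML_DEFAULT_GRAPH_SIZE,
        /* parallel           */ false,
        /* disable_op_offload */ true); // new argument added by this commit
    // ... reserve/allocate and compute graphs as before ...
    ggml_backend_sched_free(sched);

Passing false for the new argument preserves the previous behaviour.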

ggml/src/ggml-backend.cpp

Lines changed: 6 additions & 2 deletions
@@ -674,6 +674,8 @@ struct ggml_backend_sched {
     char * context_buffer;
     size_t context_buffer_size;
 
+    bool disable_op_offload;
+
     int debug;
 };
 
@@ -766,7 +768,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
         if (tensor->op != GGML_OP_ROPE && src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
             int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor);
             // check if a backend with higher prio wants to offload the op
-            if (src_backend_id == sched->n_backends - 1 && ggml_backend_buffer_is_host(src->buffer)) {
+            if (!sched->disable_op_offload && src_backend_id == sched->n_backends - 1 && ggml_backend_buffer_is_host(src->buffer)) {
                 for (int b = 0; b < src_backend_id; b++) {
                     if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) {
                         SET_CAUSE(tensor, "1.off");
@@ -1452,7 +1454,8 @@ ggml_backend_sched_t ggml_backend_sched_new(
         ggml_backend_buffer_type_t * bufts,
         int n_backends,
         size_t graph_size,
-        bool parallel) {
+        bool parallel,
+        bool disable_op_offload) {
     GGML_ASSERT(n_backends > 0);
     GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS);
     GGML_ASSERT(ggml_backend_dev_type(ggml_backend_get_device(backends[n_backends - 1])) == GGML_BACKEND_DEVICE_TYPE_CPU);
@@ -1497,6 +1500,7 @@ ggml_backend_sched_new(
     }
 
     sched->galloc = ggml_gallocr_new_n(sched->bufts, n_backends);
+    sched->disable_op_offload = disable_op_offload;
 
     ggml_backend_sched_reset(sched);

include/llama.h

Lines changed: 1 addition & 0 deletions
@@ -362,6 +362,7 @@ extern "C" {
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
         bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
         bool no_perf; // whether to measure performance timings
+        bool disable_op_offload; // whether to disable offload host tensor operations to device globally
     };
 
     // model quantization parameters
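Downstream users of the llama.h C API can set the field directly on the context params. A minimal sketch, assuming a model has already been loaded and that llama_init_from_model is available as in current llama.h:

    struct llama_context_params cparams = llama_context_default_params();
    cparams.disable_op_offload = true; // new field, defaults to false
    struct llama_context * ctx = llama_init_from_model(model, cparams);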

src/llama-context.cpp

Lines changed: 3 additions & 1 deletion
@@ -93,6 +93,7 @@ llama_context::llama_context(
     }
 
     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
+    cparams.disable_op_offload = params.disable_op_offload;
 
     const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
 
@@ -243,7 +244,7 @@ llama_context::llama_context(
         }
     }
 
-    sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));
+    sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.disable_op_offload));
 
     if (pipeline_parallel) {
         LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
@@ -1871,6 +1872,7 @@ llama_context_params llama_context_default_params() {
         /*.offload_kqv =*/ true,
         /*.flash_attn =*/ false,
         /*.no_perf =*/ true,
+        /*.disable_op_offload =*/ false,
     };
 
     return result;

src/llama-cparams.h

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ struct llama_cparams {
     bool flash_attn;
     bool no_perf;
     bool warmup;
+    bool disable_op_offload;
 
     enum llama_pooling_type pooling_type;

tests/test-opt.cpp

Lines changed: 1 addition & 1 deletion
@@ -853,7 +853,7 @@ int main(void) {
         backends_modded.insert(backends_modded.end(), backends.begin(), backends.end());
 
         ggml_backend_sched_t backend_sched = ggml_backend_sched_new(
-            backends_modded.data(), nullptr, backends_modded.size(), GGML_DEFAULT_GRAPH_SIZE, false);
+            backends_modded.data(), nullptr, backends_modded.size(), GGML_DEFAULT_GRAPH_SIZE, false, false);
 
         printf("Backend %zu/%zu: %s\n", i + 1, dev_count, ggml_backend_dev_name(devs[i]));
         printf("  Device description: %s\n", ggml_backend_dev_description(devs[i]));

tools/llama-bench/llama-bench.cpp

Lines changed: 41 additions & 10 deletions
@@ -219,6 +219,7 @@ struct cmd_params {
     std::vector<std::vector<llama_model_tensor_buft_override>> tensor_buft_overrides;
     std::vector<bool> use_mmap;
     std::vector<bool> embeddings;
+    std::vector<bool> disable_op_offload;
     ggml_numa_strategy numa;
     int reps;
     ggml_sched_priority prio;
@@ -253,6 +254,7 @@ static const cmd_params cmd_params_defaults = {
     /* tensor_buft_overrides*/ { std::vector<llama_model_tensor_buft_override>{{nullptr,nullptr}} },
     /* use_mmap */ { true },
     /* embeddings */ { false },
+    /* disable_op_offload */ { false },
     /* numa */ GGML_NUMA_STRATEGY_DISABLED,
     /* reps */ 5,
     /* prio */ GGML_SCHED_PRIO_NORMAL,
@@ -311,6 +313,7 @@ static void print_usage(int /* argc */, char ** argv) {
            join(cmd_params_defaults.embeddings, ",").c_str());
     printf(" -ts, --tensor-split <ts0/ts1/..> (default: 0)\n");
     printf(" -ot --override-tensors <tensor name pattern>=<buffer type>;... (default: disabled)\n");
+    printf(" -dopo, --disable-op-offload <i> (default: 0)\n");
     printf(" -r, --repetitions <n> (default: %d)\n", cmd_params_defaults.reps);
     printf(" --prio <0|1|2|3> (default: %d)\n", cmd_params_defaults.prio);
     printf(" --delay <0...N> (seconds) (default: %d)\n", cmd_params_defaults.delay);
@@ -588,6 +591,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
             }
             auto p = string_split<bool>(argv[i], split_delim);
             params.embeddings.insert(params.embeddings.end(), p.begin(), p.end());
+        } else if (arg == "-dopo" || arg == "--disable-op-offload") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            auto p = string_split<bool>(argv[i], split_delim);
+            params.disable_op_offload.insert(params.disable_op_offload.end(), p.begin(), p.end());
         } else if (arg == "-ts" || arg == "--tensor-split") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -794,6 +804,9 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
     if (params.embeddings.empty()) {
         params.embeddings = cmd_params_defaults.embeddings;
     }
+    if (params.disable_op_offload.empty()) {
+        params.disable_op_offload = cmd_params_defaults.disable_op_offload;
+    }
     if (params.n_threads.empty()) {
         params.n_threads = cmd_params_defaults.n_threads;
     }
@@ -833,6 +846,7 @@ struct cmd_params_instance {
     std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
     bool use_mmap;
     bool embeddings;
+    bool disable_op_offload;
 
     llama_model_params to_llama_mparams() const {
         llama_model_params mparams = llama_model_default_params();
@@ -894,14 +908,15 @@ struct cmd_params_instance {
     llama_context_params to_llama_cparams() const {
         llama_context_params cparams = llama_context_default_params();
 
-        cparams.n_ctx = n_prompt + n_gen + n_depth;
-        cparams.n_batch = n_batch;
-        cparams.n_ubatch = n_ubatch;
-        cparams.type_k = type_k;
-        cparams.type_v = type_v;
-        cparams.offload_kqv = !no_kv_offload;
-        cparams.flash_attn = flash_attn;
-        cparams.embeddings = embeddings;
+        cparams.n_ctx              = n_prompt + n_gen + n_depth;
+        cparams.n_batch            = n_batch;
+        cparams.n_ubatch           = n_ubatch;
+        cparams.type_k             = type_k;
+        cparams.type_v             = type_v;
+        cparams.offload_kqv        = !no_kv_offload;
+        cparams.flash_attn         = flash_attn;
+        cparams.embeddings         = embeddings;
+        cparams.disable_op_offload = disable_op_offload;
 
         return cparams;
     }
@@ -921,6 +936,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
     for (const auto & ot : params.tensor_buft_overrides)
     for (const auto & mmp : params.use_mmap)
     for (const auto & embd : params.embeddings)
+    for (const auto & dopo : params.disable_op_offload)
     for (const auto & nb : params.n_batch)
     for (const auto & nub : params.n_ubatch)
     for (const auto & tk : params.type_k)
@@ -959,6 +975,7 @@
                 /* .tensor_buft_overrides = */ ot,
                 /* .use_mmap = */ mmp,
                 /* .embeddings = */ embd,
+                /* .disable_op_offload= */ dopo,
             };
             instances.push_back(instance);
         }
@@ -990,6 +1007,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .tensor_buft_overrides = */ ot,
                 /* .use_mmap = */ mmp,
                 /* .embeddings = */ embd,
+                /* .disable_op_offload= */ dopo,
             };
             instances.push_back(instance);
         }
@@ -1021,6 +1039,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .tensor_buft_overrides = */ ot,
                 /* .use_mmap = */ mmp,
                 /* .embeddings = */ embd,
+                /* .disable_op_offload= */ dopo,
             };
             instances.push_back(instance);
         }
@@ -1056,6 +1075,7 @@ struct test {
     std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
     bool use_mmap;
     bool embeddings;
+    bool disable_op_offload;
     int n_prompt;
     int n_gen;
     int n_depth;
@@ -1089,6 +1109,7 @@
         tensor_buft_overrides = inst.tensor_buft_overrides;
         use_mmap = inst.use_mmap;
         embeddings = inst.embeddings;
+        disable_op_offload = inst.disable_op_offload;
         n_prompt = inst.n_prompt;
         n_gen = inst.n_gen;
         n_depth = inst.n_depth;
@@ -1134,7 +1155,7 @@
             "model_type", "model_size", "model_n_params", "n_batch", "n_ubatch", "n_threads",
             "cpu_mask", "cpu_strict", "poll", "type_k", "type_v", "n_gpu_layers",
             "split_mode", "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "tensor_buft_overrides",
-            "use_mmap", "embeddings", "n_prompt", "n_gen", "n_depth", "test_time",
+            "use_mmap", "embeddings", "disable_op_offload", "n_prompt", "n_gen", "n_depth", "test_time",
            "avg_ns", "stddev_ns", "avg_ts", "stddev_ts",
         };
         return fields;
@@ -1146,7 +1167,7 @@
         if (field == "build_number" || field == "n_batch" || field == "n_ubatch" || field == "n_threads" ||
             field == "poll" || field == "model_size" || field == "model_n_params" || field == "n_gpu_layers" ||
             field == "main_gpu" || field == "n_prompt" || field == "n_gen" || field == "n_depth" ||
-            field == "avg_ns" || field == "stddev_ns") {
+            field == "avg_ns" || field == "stddev_ns" || field == "disable_op_offload") {
             return INT;
         }
         if (field == "f16_kv" || field == "no_kv_offload" || field == "cpu_strict" || field == "flash_attn" ||
@@ -1222,6 +1243,7 @@
             tensor_buft_overrides_str,
             std::to_string(use_mmap),
             std::to_string(embeddings),
+            std::to_string(disable_op_offload),
             std::to_string(n_prompt),
             std::to_string(n_gen),
             std::to_string(n_depth),
@@ -1404,6 +1426,9 @@ struct markdown_printer : public printer {
         if (field == "test") {
             return 15;
         }
+        if (field == "disable_op_offload") {
+            return 4;
+        }
 
         int width = std::max((int) field.length(), 10);
 
@@ -1435,6 +1460,9 @@
         if (field == "embeddings") {
             return "embd";
         }
+        if (field == "disable_op_offload") {
+            return "dopo";
+        }
         if (field == "tensor_split") {
             return "ts";
         }
@@ -1503,6 +1531,9 @@
     if (params.embeddings.size() > 1 || params.embeddings != cmd_params_defaults.embeddings) {
         fields.emplace_back("embeddings");
     }
+    if (params.disable_op_offload.size() > 1 || params.disable_op_offload != cmd_params_defaults.disable_op_offload) {
+        fields.emplace_back("disable_op_offload");
+    }
     fields.emplace_back("test");
     fields.emplace_back("t/s");
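In llama-bench the new setting is a sweepable axis like the other boolean parameters, so both values can be compared in a single run. A hypothetical invocation (model path and layer count are placeholders):

    llama-bench -m model.gguf -ngl 99 -dopo 0,1

Because the field is only printed when more than one value is given or the value differs from the default, such a run adds a "dopo" column to the markdown output.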
