Commit c87b196
feat(amx): add --amx toggle; prefer CPU 'extra' with GPU host+mmap when enabled
- CLI/server/bench: --amx (presence = enabled) -> mparams.amx_enable_mmap
- Loader: with mmap + a GPU host buft, prefer CPU 'extra' if supported (AMX repack), else fall back to the CPU default
- llama-bench: add --amx flag to match CLI/server behavior
1 parent 835b2b9 commit c87b196

6 files changed (+149, -64 lines)

common/arg.cpp

Lines changed: 8 additions & 0 deletions

@@ -2538,6 +2538,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.use_mmap = false;
         }
     ).set_env("LLAMA_ARG_NO_MMAP"));
+    add_opt(common_arg(
+        {"--amx"},
+        "enable AMX-aware CPU repack when mmap is on and a GPU host buffer would be used; prefers CPU \"extra\" buffer types (e.g., AMX) for weights on CPU.",
+        [](common_params & params) {
+            params.amx_enable_mmap = true;
+        }
+    ));
+
     add_opt(common_arg(
         {"--numa"}, "TYPE",
         "attempt optimizations that help on some NUMA systems\n"

common/common.cpp

Lines changed: 22 additions & 8 deletions

@@ -1109,28 +1109,42 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
         mparams.n_gpu_layers = params.n_gpu_layers;
     }
 
-    mparams.main_gpu      = params.main_gpu;
-    mparams.split_mode    = params.split_mode;
-    mparams.tensor_split  = params.tensor_split;
-    mparams.use_mmap      = params.use_mmap;
-    mparams.use_mlock     = params.use_mlock;
-    mparams.check_tensors = params.check_tensors;
+    mparams.main_gpu   = params.main_gpu;
+    mparams.split_mode = params.split_mode;
+
+    // NOTE: common_params::tensor_split is a C-array (float [LLAMA_MAX_DEVICES])
+    // Upstream expects a pointer to the first element – do NOT use .data().
+    mparams.tensor_split = params.tensor_split;
+
+    mparams.use_mmap      = params.use_mmap;
+    mparams.use_mlock     = params.use_mlock;
+    mparams.check_tensors = params.check_tensors;
+
+    // Keep upstream policy: disable extra buffer types when --no-extra-bufts is set
     mparams.use_extra_bufts = !params.no_extra_bufts;
 
+    // NEW: forward the AMX toggle from CLI into model params
+    mparams.amx_enable_mmap = params.amx_enable_mmap;
+
+    // Preserve upstream sentinel handling for KV overrides
     if (params.kv_overrides.empty()) {
         mparams.kv_overrides = NULL;
     } else {
-        GGML_ASSERT(params.kv_overrides.back().key[0] == 0 && "KV overrides not terminated with empty key");
+        GGML_ASSERT(params.kv_overrides.back().key[0] == 0 &&
+                    "KV overrides not terminated with empty key");
         mparams.kv_overrides = params.kv_overrides.data();
     }
 
+    // Preserve upstream sentinel handling for tensor buffer overrides
     if (params.tensor_buft_overrides.empty()) {
         mparams.tensor_buft_overrides = NULL;
     } else {
-        GGML_ASSERT(params.tensor_buft_overrides.back().pattern == nullptr && "Tensor buffer overrides not terminated with empty pattern");
+        GGML_ASSERT(params.tensor_buft_overrides.back().pattern == nullptr &&
+                    "Tensor buffer overrides not terminated with empty pattern");
         mparams.tensor_buft_overrides = params.tensor_buft_overrides.data();
     }
 
+    // Keep upstream progress callback wiring
     mparams.progress_callback           = params.load_progress_callback;
     mparams.progress_callback_user_data = params.load_progress_callback_user_data;
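
The reflowed GGML_ASSERTs encode a sentinel convention: both override vectors are handed over as raw pointers, so callers must append one terminating element first (empty key for KV overrides, null pattern for buffer-type overrides). An illustrative sketch of building a properly terminated KV-override list; struct and enum names follow llama_model_kv_override in llama.h, but treat the details as assumptions:

// Illustrative: the overrides vector must end with a sentinel whose key is empty,
// matching the GGML_ASSERT(params.kv_overrides.back().key[0] == 0 && ...) check.
#include "llama.h"
#include <cstring>
#include <vector>

std::vector<llama_model_kv_override> make_kv_overrides() {
    std::vector<llama_model_kv_override> kv;

    llama_model_kv_override o = {};
    o.tag = LLAMA_KV_OVERRIDE_TYPE_INT;      // assumed enum name from llama.h
    std::strncpy(o.key, "example.key", sizeof(o.key) - 1);
    o.val_i64 = 42;
    kv.push_back(o);

    kv.push_back(llama_model_kv_override{}); // sentinel: zeroed element, key[0] == 0
    return kv;
}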

common/common.h

Lines changed: 2 additions & 0 deletions

@@ -392,6 +392,8 @@ struct common_params {
     bool check_tensors   = false; // validate tensor data
     bool no_op_offload   = false; // globally disable offload host tensor operations to device
     bool no_extra_bufts  = false; // disable extra buffer types (used for weight repacking)
+    bool amx_enable_mmap = false; // prefer CPU "extra" buffers when GPU host+mmap is chosen (enable AMX)
+
     bool single_turn = false; // single turn chat conversation
include/llama.h

Lines changed: 1 addition & 0 deletions

@@ -296,6 +296,7 @@ extern "C" {
         bool use_mlock;       // force system to keep model in RAM
         bool check_tensors;   // validate model tensor data
         bool use_extra_bufts; // use extra buffer types (used for weight repacking)
+        bool amx_enable_mmap; // prefer CPU 'extra' buffers with GPU host+mmap (enable AMX repack on CPU)
     };
 
     // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
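
Because the toggle now lives in the public llama_model_params, it can be set straight through the C API as well as via --amx. A minimal sketch, assuming the current llama_model_load_from_file / llama_model_free entry points (error handling trimmed):

// Minimal sketch: enable the AMX-aware mmap path via the public API.
#include "llama.h"
#include <cstdio>

int main(int argc, char ** argv) {
    if (argc < 2) {
        std::fprintf(stderr, "usage: %s <model.gguf>\n", argv[0]);
        return 1;
    }

    llama_model_params mparams = llama_model_default_params();
    mparams.use_mmap        = true;  // the toggle only matters on the mmap path
    mparams.amx_enable_mmap = true;  // prefer CPU "extra" (AMX) buffer types

    llama_model * model = llama_model_load_from_file(argv[1], mparams);
    if (!model) {
        std::fprintf(stderr, "failed to load model\n");
        return 1;
    }

    llama_model_free(model);
    return 0;
}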

src/llama-model.cpp

Lines changed: 47 additions & 14 deletions

@@ -2288,24 +2288,56 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
             }
         }
 
-        // avoid using a host buffer when using mmap
-        auto * buft_dev = ggml_backend_buft_get_device(buft);
-        if (ml.use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) {
-            auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
-            if (!cpu_dev) {
-                throw std::runtime_error("no CPU backend found");
-            }
-            buft = ggml_backend_dev_buffer_type(cpu_dev);
+        // avoid using a host buffer when using mmap
+        auto * buft_dev = ggml_backend_buft_get_device(buft);
+        if (ml.use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) {
+            auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+            if (!cpu_dev) {
+                throw std::runtime_error("no CPU backend found");
+            }
+
+            // If enabled, prefer CPU "extra" (AMX) buffer types for weights on CPU; else use CPU default
+            ggml_backend_buffer_type_t cpu_default_buft = ggml_backend_dev_buffer_type(cpu_dev);
+            const bool prefer_cpu_extra = params.amx_enable_mmap;
+
+            if (!prefer_cpu_extra) {
+                buft = cpu_default_buft;
+            } else {
+                ggml_backend_buffer_type_t chosen = nullptr;
+
+                // Iterate available buffer types, skipping device-host buffer types
+                for (const auto & cur : *buft_list) {
+                    ggml_backend_dev_t cur_dev = cur.first;
+                    ggml_backend_buffer_type_t cur_buft = cur.second;
+
+                    if (cur_dev && cur_buft == ggml_backend_dev_host_buffer_type(cur_dev)) {
+                        continue;
                     }
 
-        if (buft != buft_list->front().second) {
-            n_moved_tensors++;
-            if (!first_moved_tensor) {
-                first_moved_tensor = t_meta;
-                first_moved_from_buft = buft_list->front().second;
-                first_moved_to_buft = buft;
+                    // Prefer CPU "extra" (non-default) if supported for this tensor/op
+                    if (cur_dev == cpu_dev && cur_buft != cpu_default_buft) {
+                        if (weight_buft_supported(hparams, t_meta, op, cur_buft, cur_dev)) {
+                            chosen = cur_buft;
+                            break;
                         }
                     }
+                }
+
+                buft = chosen ? chosen : cpu_default_buft;
+            }
+        }
+
+        // (keep your existing moved-tensors accounting exactly as-is)
+        if (buft != buft_list->front().second) {
+            n_moved_tensors++;
+            if (!first_moved_tensor) {
+                first_moved_tensor    = t_meta;
+                first_moved_from_buft = buft_list->front().second;
+                first_moved_to_buft   = buft;
+            }
+        }
 
         ggml_context * ctx = ctx_for_buft(buft);
 
@@ -19642,6 +19674,7 @@ llama_model_params llama_model_default_params() {
         /*.use_mlock       =*/ false,
        /*.check_tensors   =*/ false,
        /*.use_extra_bufts =*/ true,
+        /*.amx_enable_mmap =*/ false,
     };
 
     return result;
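
Stripped of the ggml plumbing, the new loader logic is a small selection policy: skip device-host buffer types, take the first supported non-default CPU "extra" type, otherwise fall back to the CPU default. A self-contained toy sketch of that policy with invented stand-in types (BuftEntry and choose_cpu_buft are illustrative, not ggml API):

// Toy model of the loader's buffer-type choice (stand-in types, not ggml).
#include <functional>
#include <iostream>
#include <string>
#include <vector>

struct BuftEntry {
    std::string dev;   // owning device ("cpu", "gpu", ...)
    std::string buft;  // buffer type name
    bool is_host;      // device-host buffer type?
};

// Pick a CPU buffer type for a weight: prefer a non-default CPU "extra"
// type that the support predicate accepts, else the CPU default.
std::string choose_cpu_buft(const std::vector<BuftEntry> & list,
                            const std::string & cpu_default,
                            bool prefer_cpu_extra,
                            const std::function<bool(const BuftEntry &)> & supported) {
    if (!prefer_cpu_extra) {
        return cpu_default;
    }
    for (const auto & cur : list) {
        if (cur.is_host) {
            continue; // never pick a device-host buffer type here
        }
        if (cur.dev == "cpu" && cur.buft != cpu_default && supported(cur)) {
            return cur.buft; // first supported CPU "extra" wins
        }
    }
    return cpu_default; // mirrors `buft = chosen ? chosen : cpu_default_buft`
}

int main() {
    std::vector<BuftEntry> list = {
        {"gpu", "cuda_host", true},
        {"cpu", "amx",       false},
        {"cpu", "cpu",       false},
    };
    auto supports_amx = [](const BuftEntry & e) { return e.buft == "amx"; };
    std::cout << choose_cpu_buft(list, "cpu", true,  supports_amx) << "\n"; // amx
    std::cout << choose_cpu_buft(list, "cpu", false, supports_amx) << "\n"; // cpu
}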
