Skip to content

Commit f11f6d6

Browse files
lksj92hsIlko Tsenov
authored andcommitted
Allow layer groups in --n-cpu-moe
Resolved merge conflicts
1 parent 8ff2060 commit f11f6d6

File tree

1 file changed

+44
-7
lines changed

1 file changed

+44
-7
lines changed

common/arg.cpp

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2552,18 +2552,55 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
25522552
}
25532553
).set_env("LLAMA_ARG_CPU_MOE"));
25542554
add_opt(common_arg(
2555-
{"--n-cpu-moe", "-ncmoe"}, "N",
2556-
"keep the Mixture of Experts (MoE) weights of the first N layers in the CPU",
2557-
[](common_params & params, int value) {
2558-
if (value < 0) {
2559-
throw std::invalid_argument("invalid value");
2555+
{"--n-cpu-moe", "-ncmoe"}, "N|N1-N2|N1,N2,...",
2556+
"keep the Mixture of Experts (MoE) weights of specified layers in the CPU\n"
2557+
"- N: keep layers 0 to N-1 in CPU\n"
2558+
"- N1-N2: keep layers N1 to N2 (inclusive) in CPU\n"
2559+
"- N1,N2,...: keep specific layers in CPU (comma-separated)",
2560+
[](common_params & params, const std::string & value) {
2561+
std::vector<int> layers_to_override;
2562+
2563+
std::stringstream ss(value);
2564+
std::string item;
2565+
while (std::getline(ss, item, ',')) {
2566+
if (item.find('-') != std::string::npos) {
2567+
// Range: N1-N2
2568+
size_t dash_pos = item.find('-');
2569+
int start = std::stoi(item.substr(0, dash_pos));
2570+
int end = std::stoi(item.substr(dash_pos + 1));
2571+
if (start < 0 || end < 0 || start > end) {
2572+
throw std::invalid_argument("invalid range");
2573+
}
2574+
for (int i = start; i <= end; ++i) {
2575+
layers_to_override.push_back(i);
2576+
}
2577+
} else {
2578+
int n = std::stoi(item);
2579+
if (n < 0) {
2580+
throw std::invalid_argument("invalid value");
2581+
}
2582+
// Single value: treat as range 0 to N-1
2583+
if (value.find(',') == std::string::npos) {
2584+
for (int i = 0; i < n; ++i) {
2585+
layers_to_override.push_back(i);
2586+
}
2587+
} else {
2588+
// Value in a list: specific layer index
2589+
layers_to_override.push_back(n);
2590+
}
2591+
}
25602592
}
2561-
for (int i = 0; i < value; ++i) {
2593+
2594+
for (int layer_idx : layers_to_override) {
25622595
// keep strings alive and avoid leaking memory by storing them in a static vector
25632596
static std::list<std::string> buft_overrides;
2564-
buft_overrides.push_back(llm_ffn_exps_block_regex(i));
2597+
buft_overrides.push_back(llm_ffn_exps_block_regex(layer_idx));
25652598
params.tensor_buft_overrides.push_back({buft_overrides.back().c_str(), ggml_backend_cpu_buffer_type()});
25662599
}
2600+
2601+
if (!layers_to_override.empty()) {
2602+
LOG_INF("args: --n-cpu-moe overriding %zu layers\n", layers_to_override.size());
2603+
}
25672604
}
25682605
).set_env("LLAMA_ARG_N_CPU_MOE"));
25692606
add_opt(common_arg(

0 commit comments

Comments
 (0)