Skip to content

Commit 7fd47f1

Browse files
author
Judd
committed
MoE CPU offloading
1 parent d0431e2 commit 7fd47f1

File tree

10 files changed

+110
-38
lines changed

10 files changed

+110
-38
lines changed

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ pure C++ implementation based on [@ggerganov](https://github.com/ggerganov)'s [g
1313

1414
**What's New:**
1515

16-
* 2024-07-14: [ggml updated](https://github.com/ggml-org/llama.cpp/tree/0f2bbe656473177538956d22b6842bcaa0449fab) again
16+
* 2025-02-19: MoE CPU offloading
17+
* 2025-02-17: [ggml updated](https://github.com/ggml-org/llama.cpp/tree/0f2bbe656473177538956d22b6842bcaa0449fab) again
1718
* 2025-02-10: [GPU acceleration](./docs/gpu.md) 🔥
1819
* 2025-01-25: MiniCPM Embedding & ReRanker
1920
* 2025-01-21: DeepSeek-R1-Distill-Llama & Qwen
@@ -32,7 +33,7 @@ pure C++ implementation based on [@ggerganov](https://github.com/ggerganov)'s [g
3233

3334
## Features
3435

35-
* [x] Accelerated memory-efficient CPU inference with int4/int8 quantization, optimized KV cache and parallel computing;
36+
* [x] Accelerated memory-efficient CPU/GPU inference with int4/int8 quantization, optimized KV cache and parallel computing;
3637
* [x] Use OOP to address the similarities between different _Transformer_ based models;
3738
* [x] Streaming generation with typewriter effect;
3839
* [x] Continuous chatting (content length is virtually unlimited)

README_ja.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
## 特徴
1414

15-
* [x] int4/int8 量子化、最適化された KV キャッシュ、並列計算によるメモリ効率の高い CPU 推論の加速
15+
* [x] int4/int8 量子化、最適化された KV キャッシュ、並列計算によるメモリ効率の高い CPU/GPU 推論の加速
1616
* [x] OOP を使用して、異なる _Transformer_ ベースのモデル間の類似性に対処
1717
* [x] タイプライター効果を伴うストリーミング生成
1818
* [x] 継続的なチャット(コンテンツの長さは事実上無制限)

README_zh.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
## 特点
1414

15-
- [x] 内存高效、加速 CPU 推理:使用 int4/int8 量化、优化的 KV 缓存和并行计算。
15+
- [x] 内存高效、加速 CPU/GPU 推理:使用 int4/int8 量化、优化的 KV 缓存和并行计算。
1616
- [x] 面向对象编程:关注基于 Transformer 的模型之间的相似性。
1717
- [x] 流式生成:打字机效果。
1818
- [x] 连续聊天:内容长度几乎无限。

src/backend.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ namespace chatllm
119119
total[usage] += size;
120120
ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(get_allocator(usage), size);
121121

122-
CHATLLM_CHECK(buf) << __FUNCTION__ << "() failed to allocate buffer";
122+
CHATLLM_CHECK(buf) << __FUNCTION__ << "() failed to allocate buffer of size " << size;
123123

124124
auto r = new BackendBuffer(buf);
125125
buffers.emplace_back(r);
@@ -261,6 +261,11 @@ namespace chatllm
261261
alloc_of_tensor.insert_or_assign(tensor, allocator);
262262
}
263263

264+
void LayerAllocatorManager::override_to_cpu_only(bool flag)
265+
{
266+
cpu_override = flag;
267+
}
268+
264269
int LayerAllocatorManager::get_mapped_layer_id(int layer_id)
265270
{
266271
int id = layer_id;
@@ -276,7 +281,7 @@ namespace chatllm
276281
default:
277282
break;
278283
}
279-
if ((id < 0) || (id >= (int)allocators.size()))
284+
if (cpu_override || (id < 0) || (id >= (int)allocators.size()))
280285
id = (int)allocators.size() - 1;
281286

282287
return id;
@@ -734,6 +739,11 @@ namespace chatllm
734739
backend_context->layer_allocators.move_to_layer(layer_id);
735740
}
736741

742+
void ComputeContext::backend_cpu_override(bool flag)
743+
{
744+
backend_context->layer_allocators.override_to_cpu_only(flag);
745+
}
746+
737747
BackendBufAllocator *ComputeContext::get_allocator(void)
738748
{
739749
return backend_context->layer_allocators.get_allocator();

src/backend.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ namespace chatllm
166166

167167
void register_tensor_allocator(ggml::tensor *tensor, LayerBufAllocator *allocator);
168168

169+
void override_to_cpu_only(bool flag);
170+
169171
protected:
170172
int get_mapped_layer_id(int layer_id);
171173
public:
@@ -175,6 +177,7 @@ namespace chatllm
175177
int epilog_layer_backend_map_to_layer_id = -1;
176178
int cur_layer = MiscLayer::Prolog;
177179
std::map<ggml::tensor *, LayerBufAllocator *> alloc_of_tensor;
180+
bool cpu_override = false;
178181
};
179182

180183
class ComputeManager
@@ -319,6 +322,12 @@ namespace chatllm
319322
class ComputeContext
320323
{
321324
public:
325+
// additional user options
326+
struct UserOptions
327+
{
328+
bool moe_on_cpu = false;
329+
};
330+
322331
ComputeContext(BackendContext *backend_context);
323332

324333
virtual struct ggml_context *get_ctx() = 0;
@@ -328,6 +337,7 @@ namespace chatllm
328337
virtual void cb_op_tensor(ggml::tensor *tensor);
329338

330339
virtual void move_to_layer(int layer_id);
340+
virtual void backend_cpu_override(bool flag);
331341

332342
BackendBufAllocator *get_allocator(void);
333343
BackendBufAllocator *get_allocator(ggml::tensor *tensor);
@@ -352,6 +362,9 @@ namespace chatllm
352362

353363
BackendContext *get_backend_context(void) { return backend_context; }
354364

365+
public:
366+
UserOptions user_options;
367+
355368
protected:
356369
virtual ggml_backend_sched_t get_sched(void);
357370

src/chat.h

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,51 @@ namespace chatllm
280280
ggml::type dtype;
281281
};
282282

283+
class LayerMover
284+
{
285+
public:
286+
LayerMover(InitContext *ctx, int layer_id): ctx(ctx)
287+
{
288+
ctx->move_to_layer(layer_id);
289+
}
290+
291+
operator InitContext *() const
292+
{
293+
return ctx;
294+
}
295+
private:
296+
InitContext *ctx;
297+
};
298+
299+
class CPUMover
300+
{
301+
public:
302+
CPUMover(ComputeContext *ctx, bool activated): ctx(ctx), activated(activated)
303+
{
304+
if (activated)
305+
ctx->backend_cpu_override(true);
306+
}
307+
308+
~CPUMover()
309+
{
310+
if (activated)
311+
ctx->backend_cpu_override(false);
312+
}
313+
314+
operator InitContext *() const
315+
{
316+
return dynamic_cast<InitContext *>(ctx);
317+
}
318+
319+
operator ComputeContext *() const
320+
{
321+
return ctx;
322+
}
323+
private:
324+
ComputeContext *ctx;
325+
const bool activated;
326+
};
327+
283328
class ChunkInterceptor;
284329

285330
class BaseStreamer
@@ -844,10 +889,11 @@ namespace chatllm
844889
int max_length;
845890
std::string layer_spec;
846891
std::string gpu_layers;
847-
extra_args(int max_length, const std::string &layer_spec, const std::string &gpu_layers)
848-
: max_length(max_length), layer_spec(layer_spec), gpu_layers(gpu_layers)
892+
bool moe_on_cpu;
893+
extra_args(int max_length, const std::string &layer_spec, const std::string &gpu_layers, bool moe_on_cpu)
894+
: max_length(max_length), layer_spec(layer_spec), gpu_layers(gpu_layers), moe_on_cpu(moe_on_cpu)
849895
{}
850-
extra_args() : extra_args(-1, "", "") {}
896+
extra_args() : extra_args(-1, "", "", false) {}
851897
};
852898

853899
ModelObject(const std::string &path);

src/layers.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1450,6 +1450,8 @@ namespace chatllm
14501450

14511451
ggml::tensor * logits = gate.forward(ctx, hidden_states); // [qlen, num_experts]
14521452

1453+
CPUMover mover(ctx, ctx->user_options.moe_on_cpu);
1454+
14531455
ggml::tensor * probs = ggml::soft_max(ctx, logits); // [qlen, num_experts]
14541456

14551457
// select experts

src/layers.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1781,14 +1781,18 @@ namespace chatllm
17811781
BaseSparseMLP() = default;
17821782
BaseSparseMLP(InitContext *ctx, int hidden_size, int intermediate_size, int num_local_experts, int num_experts_per_tok,
17831783
ActFunc act, bool gate_use_bias)
1784-
: num_local_experts(num_local_experts), num_experts_per_tok(num_experts_per_tok),
1784+
:
1785+
num_local_experts(num_local_experts), num_experts_per_tok(num_experts_per_tok),
17851786
gate(ctx, hidden_size, num_local_experts, gate_use_bias),
1787+
mover(new CPUMover(ctx, ctx->user_options.moe_on_cpu)),
17861788
experts_gate(ctx, hidden_size, intermediate_size, num_local_experts),
17871789
experts_down(ctx, intermediate_size, hidden_size, num_local_experts),
17881790
experts_up (ctx, hidden_size, intermediate_size, num_local_experts),
17891791
act(act),
17901792
norm_topk_prob(true)
17911793
{
1794+
delete mover;
1795+
mover = nullptr;
17921796
}
17931797

17941798
using Block::forward;
@@ -1819,6 +1823,7 @@ namespace chatllm
18191823
const int num_local_experts;
18201824
const int num_experts_per_tok;
18211825
Linear gate;
1826+
CPUMover *mover; // when `+moe_on_cpu` is set, all things are done on CPU except for `gate`
18221827
MultiLinear experts_gate;
18231828
MultiLinear experts_down;
18241829
MultiLinear experts_up;

src/main.cpp

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ struct Args
7373
int save_session_rounds = -1;
7474
int beam_size = -1;
7575
int log_level = 4;
76+
bool moe_on_cpu = false;
7677
};
7778

7879
#define MULTI_LINE_END_MARKER_W L"\\."
@@ -125,6 +126,7 @@ void usage(const std::string &prog)
125126
<< "Performance options:\n"
126127
<< " -n, --threads N number of threads for inference (default: number of cores)\n"
127128
<< " -ngl, --n_gpu_layers N number of model layers to offload to each GPU (default: GPU not used)\n"
129+
<< "  +moe_on_cpu             always use CPU for sparse operations (MoE) (default: off)\n"
128130
<< "Sampling options:\n"
129131
<< " --sampling ALG sampling algorithm (ALG = greedy | top_p | tfs) (default: top_p) \n"
130132
<< " where, tfs = Tail Free Sampling\n"
@@ -232,6 +234,12 @@ static size_t parse_args(Args &args, const std::vector<std::string> &argv)
232234
args.field.push_back(f(argv[c].c_str())); \
233235
}
234236

237+
#define handle_flag(field) \
238+
else if ((strcmp(arg, "+" #field) == 0)) \
239+
{ \
240+
args.field = true; \
241+
}
242+
235243
size_t c = 1;
236244

237245
try
@@ -271,14 +279,9 @@ static size_t parse_args(Args &args, const std::vector<std::string> &argv)
271279
{
272280
args.reversed_role = true;
273281
}
274-
else if (strcmp(arg, "+rag_dump") == 0)
275-
{
276-
args.rag_dump = true;
277-
}
278-
else if (strcmp(arg, "+rerank_rewrite") == 0)
279-
{
280-
args.rerank_rewrite = true;
281-
}
282+
handle_flag(rag_dump)
283+
handle_flag(rerank_rewrite)
284+
handle_flag(moe_on_cpu)
282285
else if (strcmp(arg, "--format") == 0)
283286
{
284287
c++;
@@ -655,6 +658,9 @@ static void run_qa_ranker(Args &args, chatllm::Pipeline &pipeline, TextStreamer
655658
gen_config.set_ai_prefix(args.ai_prefix); gen_config.dump_dot = args.dump_dot; \
656659
gen_config.emb_rank_query_sep = args.emb_rank_query_sep;
657660

661+
#define DEF_ExtraArgs(pipe_args, args) \
662+
chatllm::ModelObject::extra_args pipe_args(args.max_length, args.layer_spec, args.n_gpu_layers, args.moe_on_cpu)
663+
658664
chatllm::BaseStreamer *get_streamer_for_log(void);
659665

660666
void log_internal(int level, const char * text)
@@ -1003,7 +1009,7 @@ int main(int argc, const char **argv)
10031009

10041010
try
10051011
{
1006-
chatllm::ModelObject::extra_args pipe_args(args.max_length, args.layer_spec, args.n_gpu_layers);
1012+
DEF_ExtraArgs(pipe_args, args);
10071013
TextStreamer streamer(nullptr);
10081014
streamer.log_level = args.log_level;
10091015
log_streamer = &streamer;
@@ -1240,7 +1246,7 @@ int chatllm_start(struct chatllm_obj *obj, f_chatllm_print f_print, f_chatllm_en
12401246

12411247
try
12421248
{
1243-
chatllm::ModelObject::extra_args pipe_args(args.max_length, args.layer_spec, args.n_gpu_layers);
1249+
DEF_ExtraArgs(pipe_args, args);
12441250

12451251
if ((args.embedding_model_path.size() < 1) || (args.vector_stores.empty()))
12461252
{

src/models.cpp

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@ namespace chatllm
2626
struct RuntimeConfig
2727
{
2828
std::string gpu_layers;
29-
RuntimeConfig(const std::string &gpu_layers): gpu_layers(gpu_layers) {}
29+
bool moe_on_cpu;
30+
RuntimeConfig(const std::string &gpu_layers, bool moe_on_cpu):
31+
gpu_layers(gpu_layers), moe_on_cpu(moe_on_cpu)
32+
{}
3033
};
3134

3235
class ForwardContext : public ComputeContext
@@ -44,22 +47,6 @@ namespace chatllm
4447
ggml_cgraph *gf;
4548
};
4649

47-
class LayerMover
48-
{
49-
public:
50-
LayerMover(InitContext *ctx, int layer_id): ctx(ctx)
51-
{
52-
ctx->move_to_layer(layer_id);
53-
}
54-
55-
operator InitContext *() const
56-
{
57-
return ctx;
58-
}
59-
private:
60-
InitContext *ctx;
61-
};
62-
6350
static ForwardContext *dbg_ctx = nullptr;
6451
static std::unordered_map<ggml::tensor *, std::string> inspected_set;
6552
static ggml::tensor *dbg_w = nullptr;
@@ -1151,6 +1138,7 @@ namespace chatllm
11511138
{
11521139
std::vector<BackendContext::gpu_cfg> gpu_cfgs;
11531140
parse_gpu_layers(gpu_cfgs, rt_config.gpu_layers);
1141+
w_ctx_.user_options.moe_on_cpu = rt_config.moe_on_cpu;
11541142
backend_context.init(gpu_cfgs, config_.num_hidden_layers, GRAPH_SIZE);
11551143
}
11561144

@@ -1209,6 +1197,7 @@ namespace chatllm
12091197
}
12101198

12111199
ForwardContext ctx(&backend_context);
1200+
ctx.user_options = w_ctx_.user_options;
12121201

12131202
ctx.gctx = GGMLContext({.mem_size = backend_context.buf_compute_meta.size(), .mem_buffer = backend_context.buf_compute_meta.data(), .no_alloc = true});
12141203
ctx.gf = ggml::new_graph_custom(&ctx, GRAPH_SIZE, false);
@@ -1933,7 +1922,7 @@ namespace chatllm
19331922
config.num_hidden_layers = (int)layers.size();
19341923
}
19351924

1936-
RuntimeConfig rt_config(args.gpu_layers);
1925+
RuntimeConfig rt_config(args.gpu_layers, args.moe_on_cpu);
19371926

19381927
// load model
19391928
ConditionalGeneration *model = new ConditionalGeneration(config, rt_config);

0 commit comments

Comments
 (0)