@@ -73,6 +73,7 @@ struct Args
     int save_session_rounds = -1;
     int beam_size = -1;
     int log_level = 4;
+    bool moe_on_cpu = false;
 };
 
 #define MULTI_LINE_END_MARKER_W L"\\."
@@ -125,6 +126,7 @@ void usage(const std::string &prog)
125126 << " Performance options:\n "
126127 << " -n, --threads N number of threads for inference (default: number of cores)\n "
127128 << " -ngl, --n_gpu_layers N number of model layers to offload to each GPU (default: GPU not used)\n "
129+ << " +moe_on_cpu alway use CPU for sparse operations (MoE) (default: off)\n "
128130 << " Sampling options:\n "
129131 << " --sampling ALG sampling algorithm (ALG = greedy | top_p | tfs) (default: top_p) \n "
130132 << " where, tfs = Tail Free Sampling\n "
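
The new `+moe_on_cpu` switch follows the existing `+name` convention for boolean flags (compare `+rag_dump` and `+rerank_rewrite` below) and composes with the other performance options; keeping the sparse MoE expert layers on the CPU while offloading the dense layers with `-ngl` is a common way to fit mixture-of-experts models into limited VRAM. A hypothetical invocation (binary name and model path are placeholders, not part of this commit):

    ./main -m some-moe-model.bin -ngl 99 +moe_on_cpu
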
@@ -232,6 +234,12 @@ static size_t parse_args(Args &args, const std::vector<std::string> &argv)
             args.field.push_back(f(argv[c].c_str()));    \
         }
 
+    #define handle_flag(field)                    \
+        else if ((strcmp(arg, "+" #field) == 0))  \
+        {                                         \
+            args.field = true;                    \
+        }
+
     size_t c = 1;
 
     try
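
A note on the new macro: `#field` stringizes the macro argument, so the accepted flag name is derived directly from the `Args` member, and the expansion starts with `else if`, which means an invocation is only well-formed immediately after a preceding `if`/`else if` block (exactly how it is used in the next hunk). Expanded, `handle_flag(moe_on_cpu)` is equivalent to:

    else if ((strcmp(arg, "+moe_on_cpu") == 0))
    {
        args.moe_on_cpu = true;
    }
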
@@ -271,14 +279,9 @@ static size_t parse_args(Args &args, const std::vector<std::string> &argv)
         {
             args.reversed_role = true;
         }
-        else if (strcmp(arg, "+rag_dump") == 0)
-        {
-            args.rag_dump = true;
-        }
-        else if (strcmp(arg, "+rerank_rewrite") == 0)
-        {
-            args.rerank_rewrite = true;
-        }
+        handle_flag(rag_dump)
+        handle_flag(rerank_rewrite)
+        handle_flag(moe_on_cpu)
         else if (strcmp(arg, "--format") == 0)
         {
             c++;
@@ -655,6 +658,9 @@ static void run_qa_ranker(Args &args, chatllm::Pipeline &pipeline, TextStreamer
     gen_config.set_ai_prefix(args.ai_prefix); gen_config.dump_dot = args.dump_dot;    \
     gen_config.emb_rank_query_sep = args.emb_rank_query_sep;
 
+#define DEF_ExtraArgs(pipe_args, args)    \
+    chatllm::ModelObject::extra_args pipe_args(args.max_length, args.layer_spec, args.n_gpu_layers, args.moe_on_cpu)
+
 chatllm::BaseStreamer *get_streamer_for_log(void);
 
 void log_internal(int level, const char *text)
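
Centralizing the `extra_args` construction in `DEF_ExtraArgs` keeps the two call sites below in sync, so a future constructor parameter only needs to be added in one place. At a call site, `DEF_ExtraArgs(pipe_args, args);` expands to (the trailing semicolon comes from the call site):

    chatllm::ModelObject::extra_args pipe_args(args.max_length, args.layer_spec, args.n_gpu_layers, args.moe_on_cpu);
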
@@ -1003,7 +1009,7 @@ int main(int argc, const char **argv)
 
     try
     {
-        chatllm::ModelObject::extra_args pipe_args(args.max_length, args.layer_spec, args.n_gpu_layers);
+        DEF_ExtraArgs(pipe_args, args);
         TextStreamer streamer(nullptr);
         streamer.log_level = args.log_level;
         log_streamer = &streamer;
@@ -1240,7 +1246,7 @@ int chatllm_start(struct chatllm_obj *obj, f_chatllm_print f_print, f_chatllm_en
 
     try
     {
-        chatllm::ModelObject::extra_args pipe_args(args.max_length, args.layer_spec, args.n_gpu_layers);
+        DEF_ExtraArgs(pipe_args, args);
 
         if ((args.embedding_model_path.size() < 1) || (args.vector_stores.empty()))
         {