Skip to content

Commit c6615da

Browse files
author
Judd
committed
new option emb_rank_query_sep; minor updates to RAG
1 parent 5e2d499 commit c6615da

File tree

4 files changed

+43
-2
lines changed

4 files changed

+43
-2
lines changed

bindings/libchatllm.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ enum PrintType
3838
PRINTLN_EMBEDDING = 8, // print a whole line: embedding (example: "0.1, 0.3, ...")
3939
PRINTLN_RANKING = 9, // print a whole line: ranking (example: "0.8")
4040
PRINTLN_TOKEN_IDS =10, // print a whole line: token ids (example: "1, 3, 5, 8, ...")
41+
PRINTLN_LOGGING =11, // print a whole line: internal logging with the first char indicating level
42+
// (space): None; D: Debug; I: Info; W: Warn; E: Error; .: continue
4143

4244
PRINT_EVT_ASYNC_COMPLETED = 100, // last async operation completed (utf8_str is null)
4345
};

src/chat.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1504,6 +1504,16 @@ namespace chatllm
15041504
rewritten_query = query;
15051505
}
15061506

1507+
if (!modelobj.loaded && (gen_config.emb_rank_query_sep.size() > 0))
1508+
{
1509+
auto pos = query.find(gen_config.emb_rank_query_sep);
1510+
if (pos != std::string::npos)
1511+
{
1512+
rewritten_query = query.substr(0, pos);
1513+
query.erase(0, pos + gen_config.emb_rank_query_sep.size());
1514+
}
1515+
}
1516+
15071517
embedding.text_embedding(rewritten_query, gen_config, query_emb);
15081518

15091519
vs.get()->Query(query_emb, selected, retrieve_top_n);
@@ -1560,6 +1570,9 @@ namespace chatllm
15601570

15611571
auto composed = composer.compose_augmented_query(query, augments);
15621572

1573+
if (!modelobj.loaded)
1574+
streamer->put_chunk(true, composed);
1575+
15631576
history[index].content = composed;
15641577
}
15651578

src/chat.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,7 @@ namespace chatllm
292292
EMBEDDING = 8,
293293
RANKING = 9,
294294
TOKEN_IDS =10,
295+
LOGGING =11,
295296
};
296297
BaseStreamer(BaseTokenizer *tokenizer);
297298
virtual ~BaseStreamer() = default;
@@ -503,6 +504,7 @@ namespace chatllm
503504
std::string sampling;
504505
std::string ai_prefix;
505506
std::string dump_dot;
507+
std::string emb_rank_query_sep;
506508

507509
GenerationConfig()
508510
{

src/main.cpp

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ struct Args
4343
std::string n_gpu_layers;
4444
std::string cur_vs_name = "default";
4545
std::string dump_dot;
46+
std::string emb_rank_query_sep;
4647
std::map<std::string, std::vector<std::string>> vector_stores;
4748
int max_length = -1;
4849
int max_context_length = 512;
@@ -150,6 +151,8 @@ void usage(const std::string &prog)
150151
<< " items with a lower score are discarded.\n"
151152
<< " --rerank_top_n N number of selected items using reranker model (default: 1)\n"
152153
<< " +rerank_rewrite reranker use the rewritten query (default: OFF, i.e. use the original user input)\n"
154+
<< " --emb_rank_query_sep separator for embedding & rerank query (default: \"\", i.e. disabled)\n"
155+
<< " only used without main model\n"
153156
<< " --hide_reference do not show references (default: false)\n"
154157
<< " --rag_template ... prompt template for RAG (macros: {context}, {question}) (optional).\n"
155158
<< " Support some C escape sequences (\\n). Example:\n"
@@ -231,7 +234,7 @@ static size_t parse_args(Args &args, const std::vector<std::string> &argv)
231234
while (c < argc)
232235
{
233236
const char *arg = argv[c].c_str();
234-
if ((strcmp(arg, "--help") == 0) || (strcmp(arg, "-h") == 0))
237+
if ((strcmp(arg, "--help") == 0) || (strcmp(arg, "-h") == 0) || (strcmp(arg, "-?") == 0))
235238
{
236239
args.show_help = true;
237240
}
@@ -345,6 +348,7 @@ static size_t parse_args(Args &args, const std::vector<std::string> &argv)
345348
handle_para0("--rag_post_extending", rag_post_extending, std::stoi)
346349
handle_para0("--rag_template", rag_template, std::string)
347350
handle_para0("--rag_context_sep", rag_context_sep, std::string)
351+
handle_para0("--emb_rank_query_sep", emb_rank_query_sep, std::string)
348352
handle_para0("--init_vs", vector_store_in, std::string)
349353
handle_para0("--merge_vs", merge_vs, std::string)
350354
handle_para0("--layer_spec", layer_spec, std::string)
@@ -641,10 +645,28 @@ static void run_qa_ranker(Args &args, chatllm::Pipeline &pipeline, TextStreamer
641645

642646
#define DEF_GenerationConfig(gen_config, args) chatllm::GenerationConfig gen_config(args.max_length, args.max_context_length, args.temp > 0, args.reversed_role, \
643647
args.top_k, args.top_p, args.temp, args.num_threads, args.sampling, args.presence_penalty, args.tfs_z); \
644-
gen_config.set_ai_prefix(args.ai_prefix); gen_config.dump_dot = args.dump_dot;
648+
gen_config.set_ai_prefix(args.ai_prefix); gen_config.dump_dot = args.dump_dot; \
649+
gen_config.emb_rank_query_sep = args.emb_rank_query_sep;
650+
651+
static void _ggml_log_callback(enum ggml_log_level level, const char * text, void * user_data)
652+
{
653+
chatllm::BaseStreamer *streamer = (chatllm::BaseStreamer *)user_data;
654+
std::ostringstream oss;
655+
static const char tags[] = {' ', 'D', 'I', 'W', 'E', '.'};
656+
657+
if ((0 <= level) && (level < sizeof(tags)))
658+
oss << tags[level];
659+
else
660+
oss << '?';
661+
662+
oss << text;
663+
streamer->putln(oss.str(), chatllm::BaseStreamer::LOGGING);
664+
}
645665

646666
void chat(Args &args, chatllm::Pipeline &pipeline, TextStreamer &streamer)
647667
{
668+
ggml_log_set(_ggml_log_callback, &streamer);
669+
648670
if (args.system.size() > 0)
649671
pipeline.set_system_prompt(args.system);
650672

@@ -1092,6 +1114,8 @@ static int start_chat(Chat *chat, Args &args, chatllm::Pipeline &pipeline, chatl
10921114
chat->pipeline = &pipeline;
10931115
chat->streamer = &streamer;
10941116

1117+
ggml_log_set(_ggml_log_callback, &streamer);
1118+
10951119
if (args.system.size() > 0)
10961120
pipeline.set_system_prompt(args.system);
10971121

0 commit comments

Comments (0)