@@ -260,43 +260,47 @@ static size_t validate_utf8(const std::string& text) {
 // template utils
 //
 
-// format rerank task:
+// format and tokenize rerank task:
 // - using SEP token: [BOS]query[EOS][SEP]doc[EOS]
-// - using prompt: <rerank_prefix>query<rerank_suffix>doc
+// - using prompt template with {query} and {document} placeholders
-static llama_tokens format_rerank(const struct llama_model * model, const llama_tokens & query, const llama_tokens & doc) {
+static std::vector<llama_tokens> tokenize_rerank(const struct llama_model * model, const std::string & query, const std::vector<std::string> & documents) {
     const llama_vocab * vocab = llama_model_get_vocab(model);
-    llama_tokens result;
+    std::vector<llama_tokens> result;
 
-    if (llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL) {
-        // Get EOS token - use SEP token as fallback if EOS is not available
-        llama_token eos_token = llama_vocab_eos(vocab);
-        if (eos_token == LLAMA_TOKEN_NULL) {
-            eos_token = llama_vocab_sep(vocab);
-        }
+    for (const auto & doc : documents) {
+        if (llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL) {
+            llama_tokens tok;
+            llama_tokens tok_query = common_tokenize(vocab, query, false, false);
+            llama_tokens tok_doc   = common_tokenize(vocab, doc,   false, false);
+            // Get EOS token - use SEP token as fallback if EOS is not available
+            llama_token eos_token = llama_vocab_eos(vocab);
+            if (eos_token == LLAMA_TOKEN_NULL) {
+                eos_token = llama_vocab_sep(vocab);
+            }
 
-        result.reserve(doc.size() + query.size() + 4);
-        result.push_back(llama_vocab_bos(vocab));
-        result.insert(result.end(), query.begin(), query.end());
-        result.push_back(eos_token);
-        result.push_back(llama_vocab_sep(vocab));
-        result.insert(result.end(), doc.begin(), doc.end());
-        result.push_back(eos_token);
-    } else {
-        // using prompt template
-        const char * prefix = llama_model_chat_template(model, "rerank_prefix");
-        const char * suffix = llama_model_chat_template(model, "rerank_suffix");
+            tok.reserve(tok_query.size() + tok_doc.size() + 4);
+            tok.push_back(llama_vocab_bos(vocab));
+            tok.insert(tok.end(), tok_query.begin(), tok_query.end());
+            tok.push_back(eos_token);
+            tok.push_back(llama_vocab_sep(vocab));
+            tok.insert(tok.end(), tok_doc.begin(), tok_doc.end());
+            tok.push_back(eos_token);
 
-        if (prefix == NULL && suffix == NULL) {
-            throw std::runtime_error("Rerank prompt template not found in the model\n");
-        }
+            result.push_back(std::move(tok));
+        } else {
+            // using prompt template
+            const char * tmpl = llama_model_chat_template(model, "rerank");
+            if (tmpl == nullptr) {
+                throw std::runtime_error("model does not have rerank template");
+            }
 
-        const llama_tokens prefix_tokens = prefix ? common_tokenize(vocab, prefix, true,  false) : llama_tokens();
-        const llama_tokens suffix_tokens = suffix ? common_tokenize(vocab, suffix, false, false) : llama_tokens();
-        result.reserve(prefix_tokens.size() + query.size() + suffix_tokens.size() + doc.size());
-        result.insert(result.end(), prefix_tokens.begin(), prefix_tokens.end());
-        result.insert(result.end(), query.begin(), query.end());
-        result.insert(result.end(), suffix_tokens.begin(), suffix_tokens.end());
-        result.insert(result.end(), doc.begin(), doc.end());
+            std::string prompt = tmpl;
+            // TODO: may not be efficient to call string_replace_all twice
+            string_replace_all(prompt, "{query}",    query);
+            string_replace_all(prompt, "{document}", doc);
+            llama_tokens tok = common_tokenize(vocab, prompt, true, false);
+            result.push_back(std::move(tok));
+        }
     }
 
     return result;
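
For reference, a minimal usage sketch of the new helper. The query/document strings and the loop body are illustrative only, not the server's actual `/v1/rerank` handler; `model` is assumed to be an already-loaded `llama_model *`:

```cpp
// hypothetical caller: one token list comes back per input document
const std::string query = "what is panda?";
const std::vector<std::string> documents = {
    "hi",
    "The giant panda is a bear species endemic to China.",
};

std::vector<llama_tokens> inputs = tokenize_rerank(model, query, documents);

for (size_t i = 0; i < inputs.size(); ++i) {
    // inputs[i] is a ready-to-decode prompt pairing the query with
    // documents[i]; the server decodes it and reads back a relevance score
}
```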
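On the TODO above: `string_replace_all` is the helper from llama.cpp's common library. A sketch of its assumed shape, to show where the cost the TODO flags comes from (each call rescans the whole prompt once per placeholder):

```cpp
// sketch only, assuming the common-library signature
static void string_replace_all(std::string & s, const std::string & search, const std::string & replace) {
    if (search.empty()) {
        return; // nothing to match
    }
    size_t pos = 0;
    while ((pos = s.find(search, pos)) != std::string::npos) {
        s.replace(pos, search.length(), replace);
        pos += replace.length(); // skip past the replacement, never rescan it
    }
}
```

Since rerank templates are short, the double scan is harmless in practice; a single-pass substitution would only matter for very large prompts.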