@@ -95,8 +95,13 @@ int main(int argc, char ** argv) {
         params.n_batch = params.n_ctx;
     }
 
-    // For non-causal models, batch size must be equal to ubatch size
-    params.n_ubatch = params.n_batch;
+    // for non-causal models, batch size must be equal to ubatch size
+    if (params.attention_type != LLAMA_ATTENTION_TYPE_CAUSAL) {
+        params.n_ubatch = params.n_batch;
+    }
+
+    // get max number of sequences per batch
+    const int n_seq_max = llama_max_parallel_sequences();
 
     llama_backend_init();
     llama_numa_init(params.numa);
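
The first hunk stops forcing n_ubatch = n_batch unconditionally: non-causal attention has to see the whole logical batch in a single physical micro-batch, so the equality is only required when the attention type is not causal (the check also covers LLAMA_ATTENTION_TYPE_UNSPECIFIED). llama_max_parallel_sequences() exposes the library's cap on sequence ids per batch, which the flush logic further down now respects. As a minimal sketch of the same constraint through the raw context API, with illustrative sizes that are not taken from the patch:

    #include "llama.h"

    // sketch: context parameters for a non-causal embedding model;
    // the sizes are illustrative, the point is the n_ubatch == n_batch
    // constraint for non-causal attention
    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx      = 2048;
    cparams.n_batch    = 2048;            // logical batch: tokens per llama_decode call
    cparams.n_ubatch   = cparams.n_batch; // physical batch must match for non-causal models
    cparams.embeddings = true;            // request embeddings instead of logits
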
@@ -144,6 +149,7 @@ int main(int argc, char ** argv) {
     // get added sep and eos token, if any
     const std::string added_sep_token = llama_vocab_get_add_sep(vocab) ? llama_vocab_get_text(vocab, llama_vocab_sep(vocab)) : "";
     const std::string added_eos_token = llama_vocab_get_add_eos(vocab) ? llama_vocab_get_text(vocab, llama_vocab_eos(vocab)) : "";
+    const char * rerank_prompt = llama_model_chat_template(model, "rerank");
 
     // tokenize the prompts and trim
     std::vector<std::vector<int32_t>> inputs;
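
llama_model_chat_template(model, "rerank") fetches a named chat template from the model's metadata and returns nullptr when the model does not ship one; the tokenization branch below keys its fallback on exactly that. A short sketch of the decision (the metadata key name is my reading of the GGUF naming convention, it is not stated in the patch):

    // nullptr unless the model carries a named template, which by
    // convention would live under tokenizer.chat_template.rerank
    // (assumption, not stated in this patch)
    const char * rerank_prompt = llama_model_chat_template(model, "rerank");
    if (rerank_prompt == nullptr) {
        // no rerank template: fall back to joining the query/document
        // pair with the vocab's EOS/SEP tokens, as the loop below does
    }
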
@@ -153,21 +159,28 @@ int main(int argc, char ** argv) {
         // split classification pairs and insert expected separator tokens
         if (pooling_type == LLAMA_POOLING_TYPE_RANK && prompt.find(params.cls_sep) != std::string::npos) {
             std::vector<std::string> pairs = split_lines(prompt, params.cls_sep);
-            std::string final_prompt;
-
-            for (size_t i = 0; i < pairs.size(); i++) {
-                final_prompt += pairs[i];
-                if (i != pairs.size() - 1) {
-                    if (!added_eos_token.empty()) {
-                        final_prompt += added_eos_token;
-                    }
-                    if (!added_sep_token.empty()) {
-                        final_prompt += added_sep_token;
+            if (rerank_prompt != nullptr) {
+                const std::string query = pairs[0];
+                const std::string doc   = pairs[1];
+                std::string final_prompt = rerank_prompt;
+                string_replace_all(final_prompt, "{query}", query);
+                string_replace_all(final_prompt, "{document}", doc);
+                inp = common_tokenize(vocab, final_prompt, true, true);
+            } else {
+                std::string final_prompt;
+                for (size_t i = 0; i < pairs.size(); i++) {
+                    final_prompt += pairs[i];
+                    if (i != pairs.size() - 1) {
+                        if (!added_eos_token.empty()) {
+                            final_prompt += added_eos_token;
+                        }
+                        if (!added_sep_token.empty()) {
+                            final_prompt += added_sep_token;
+                        }
                     }
                 }
+                inp = common_tokenize(ctx, final_prompt, true, true);
             }
-
-            inp = common_tokenize(ctx, final_prompt, true, true);
         } else {
             inp = common_tokenize(ctx, prompt, true, true);
         }
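
When a rerank template is present, building the prompt is plain placeholder substitution on the {query} and {document} markers, and the result is tokenized against the vocab directly. A self-contained sketch with a made-up template string (a real template comes from the model file, and string_replace_all here is a local stand-in for the helper in common/common.h):

    #include <cstdio>
    #include <string>

    // local stand-in for common's string_replace_all helper
    static void string_replace_all(std::string & s, const std::string & search, const std::string & replace) {
        for (size_t pos = 0; (pos = s.find(search, pos)) != std::string::npos; pos += replace.size()) {
            s.replace(pos, search.size(), replace);
        }
    }

    int main() {
        // made-up template, for illustration only
        std::string final_prompt = "query: {query}\ndocument: {document}\nrelevant:";
        string_replace_all(final_prompt, "{query}",    "what is a panda?");
        string_replace_all(final_prompt, "{document}", "The giant panda is a bear species endemic to China.");
        printf("%s\n", final_prompt.c_str());
        return 0;
    }
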
@@ -229,7 +242,7 @@ int main(int argc, char ** argv) {
         const uint64_t n_toks = inp.size();
 
         // encode if at capacity
-        if (batch.n_tokens + n_toks > n_batch) {
+        if (batch.n_tokens + n_toks > n_batch || s >= n_seq_max) {
             float * out = emb + e * n_embd;
             batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
             e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
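
Each prompt is added to the batch under its own sequence id s, so there are now two independent flush triggers: running out of token budget (n_batch) and running out of sequence ids (n_seq_max). A simplified sketch of the loop's control flow; add_seq_to_batch is a hypothetical stand-in for the example's local batch-filling helper:

    // simplified batching loop: flush on either the token budget or
    // the sequence-id cap, whichever is hit first
    int e = 0; // embeddings written so far
    int s = 0; // sequences in the current batch
    for (const auto & inp : inputs) {
        if (batch.n_tokens + inp.size() > n_batch || s >= n_seq_max) {
            float * out = emb + e * n_embd;
            batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
            e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
            s = 0;
            common_batch_clear(batch);
        }
        add_seq_to_batch(batch, inp, s); // hypothetical helper: one sequence id per prompt
        s++;
    }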