@@ -163,7 +163,7 @@ struct server_slot {
     int32_t n_prompt_tokens           = 0;
     int32_t n_prompt_tokens_processed = 0;
 
-    json prompt; // can be either a string, array of strings or array of token ids
+    json prompt; // can be either a string, array of strings, array of token ids, or mixed array of strings and token ids
 
     // when a task is submitted, we first tokenize the prompt and store it here
     std::vector<llama_token> prompt_tokens;
@@ -975,16 +975,15 @@ struct server_context {
             }
 
             if ((prompt->is_string()) ||
-                (prompt->is_array() &&  prompt->size() == 1 && prompt->at(0).is_string()) ||
-                (prompt->is_array() && !prompt->empty()     && prompt->at(0).is_number_integer())) {
+                (prompt->is_array() && !prompt->empty() && (prompt->at(0).is_string() || prompt->at(0).is_number_integer()))) {
                 slot.prompt = *prompt;
             } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
                 slot.prompt = prompt->at(0);
             } else if (prompt->is_array() && prompt->size() > 1) {
                 // array of strings
                 for (const auto & el : *prompt) {
                     if (!el.is_string()) {
-                        send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
+                        send_error(task, "\"prompt\" must be a string, an array of strings, an array of integers, or a mixed array of strings and integers", ERROR_TYPE_INVALID_REQUEST);
                         return false;
                     }
                 }
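As a concrete illustration of the prompt shapes named in the updated `slot.prompt` comment, the values below are made up for this note (they are not part of the patch); `json` is the nlohmann::json alias server.cpp already uses:

```cpp
// Illustrative prompt values only; not from the patch itself.
json p1 = "Once upon a time";                    // string
json p2 = { 12, 345, 678 };                      // array of token ids
json p3 = { "Once upon", " a time" };            // array of strings
json p4 = { "Once upon", 345, " a time", 678 };  // mixed array of strings and token ids (newly allowed)
```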
@@ -1062,18 +1061,10 @@ struct server_context {
         }
 
         {
-            // These lines seem to force the clearing of sampler data between generations:
-
-            // if (slot.smpl != nullptr) {
-            //     gpt_sampler_free(slot.smpl);
-            // }
-            // slot.smpl = gpt_sampler_init(model, slot.sparams);
-
-            // Changed it to this so data could be maintained between generations:
-
-            if (slot.smpl == nullptr) {
-                slot.smpl = gpt_sampler_init(model, slot.sparams);
+            if (slot.smpl != nullptr) {
+                gpt_sampler_free(slot.smpl);
             }
+            slot.smpl = gpt_sampler_init(model, slot.sparams, slot.n_ctx);
 
             if (slot.smpl == nullptr) {
                 // for now, the only error that may happen here is invalid grammar
@@ -1518,24 +1509,25 @@ struct server_context {
             throw std::runtime_error(error_msg);
         }
         json prompt = data.at("prompt");
-        // if the prompt is a singleton (i.e. a string, a list of tokens, or a mixed array of strings and tokens), we only need to create a single task
-        if (prompt.is_string() || (prompt.is_array() && !prompt.empty() && !prompt[0].is_array())) {
-            bool is_mixed = false;
-            bool has_string = prompt.is_string();
+
+        auto is_valid_singleton_array = [](const json & arr) {
             bool has_number = false;
-            if (prompt.is_array()) {
-                for (const auto & elem : prompt) {
-                    if (elem.is_string()) has_string = true;
-                    else if (elem.is_number()) has_number = true;
-                    if (has_string && has_number) {
-                        is_mixed = true;
-                        break;
-                    }
+            for (const auto & elem : arr) {
+                if (elem.is_number()) {
+                    has_number = true;
+                } else if (!elem.is_string()) {
+                    return false;
                 }
             }
+            return has_number;
+        };
+
+        bool is_singleton = prompt.is_string() || (prompt.is_array() && is_valid_singleton_array(prompt));
+
+        // if the prompt is a singleton (i.e. a string, a list of tokens, or a mixed array of strings and tokens), we only need to create a single task
+        if (is_singleton) {
             data["index"] = 0;
             create_task(data, false, nullptr);
-            SRV_DBG("creating single%s prompt task\n", is_mixed ? " mixed" : "");
         }
         // otherwise, it's a multiple-prompt task or a rerank task, we break it into smaller tasks
         else if (prompt.is_array()) {
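To make the new classification concrete, here is a standalone sketch of `is_valid_singleton_array` with the same body as the lambda above; the `main` and the test values are mine, not part of the patch. An array is treated as a singleton prompt only if every element is a string or a number and at least one element is a number, so a plain array of strings still falls through to the multi-prompt branch.

```cpp
// Standalone sketch of the singleton check above (test values invented for illustration).
#include <nlohmann/json.hpp>
#include <cassert>
using json = nlohmann::json;

static bool is_valid_singleton_array(const json & arr) {
    bool has_number = false;
    for (const auto & elem : arr) {
        if (elem.is_number()) {
            has_number = true;
        } else if (!elem.is_string()) {
            return false;   // anything other than strings/numbers disqualifies the array
        }
    }
    return has_number;      // at least one token id is required
}

int main() {
    assert( is_valid_singleton_array(json::array({12, 345, 678})));                  // pure token ids -> single task
    assert( is_valid_singleton_array(json::array({"Once upon", 345, " a time"})));   // mixed -> single task
    assert(!is_valid_singleton_array(json::array({"Once upon", " a time"})));        // strings only -> multi-prompt path
    assert(!is_valid_singleton_array(json::array({"a", json::array({1, 2})})));      // nested array -> multi-prompt path
    return 0;
}
```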
@@ -2154,7 +2146,8 @@ struct server_context {
                                 GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
                             }
 
-                            // gpt_sampler_reset(slot.smpl);                     // This line is likely preventing sampler state from being maintained from generation to generation
+                            // Should this be (re-)moved?
+                            gpt_sampler_reset(slot.smpl);
 
                             if (!slot.params.cache_prompt) {
                                 slot.n_past_se = 0;
@@ -2165,10 +2158,13 @@ struct server_context {
                                 // reuse any previously computed tokens that are common with the new prompt
                                 slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
 
+                                // Not sure whether the for loop below should run in more than one place, but for now I moved it
+                                // to after the entire prompt is processed so that sampling happens consistently.
+
                                 // push the prompt into the sampling context (do not apply grammar)
-                                for (int i = 0; i < slot.n_past; ++i) {
-                                    gpt_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
-                                }
+                                //  for (int i = 0; i < slot.n_past; ++i) {
+                                //      gpt_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
+                                //  }
                             }
                         }
 
@@ -2264,6 +2260,11 @@ struct server_context {
 
                         GGML_ASSERT(batch.n_tokens > 0);
 
+                        // Process all prompt tokens through the sampler system
+                        for (int i = 0; i < slot.n_prompt_tokens; ++i) {
+                            gpt_sampler_accept(slot.smpl, prompt_tokens[i], false);
+                        }
+
                         // extract the logits only for the last token
                         batch.logits[batch.n_tokens - 1] = true;
 
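Taking the sampler-related hunks together, this is the per-slot lifecycle as I read the combined changes. It is a sketch only: `gpt_sampler_sample` and the generation-side call sites are assumptions based on the surrounding server code, not shown in this diff.

```cpp
// Sketch of the sampler lifecycle per slot after these changes (my reading, not code from the patch):
//
// slot launch (the hunk around old line 1062):
//     if (slot.smpl != nullptr) gpt_sampler_free(slot.smpl);        // always start from a fresh sampler
//     slot.smpl = gpt_sampler_init(model, slot.sparams, slot.n_ctx);
//
// prompt processing:
//     gpt_sampler_reset(slot.smpl);                                 // the call flagged "Should this be (re-)moved?"
//     for (int i = 0; i < slot.n_prompt_tokens; ++i) {
//         gpt_sampler_accept(slot.smpl, prompt_tokens[i], false);   // penalties see the prompt; grammar is not advanced
//     }
//
// generation (assumed, outside this diff):
//     const llama_token id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch);
//     gpt_sampler_accept(slot.smpl, id, true);                      // grammar advances only on sampled tokens
```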