@@ -595,6 +595,7 @@ class LlamaData {
     std::vector<llama_chat_message> messages; // TODO: switch to common_chat_msg
     std::list<std::string> msg_strs;
     std::vector<char> fmtted;
+    llama_pos n_past = 0;

     int init(Opt & opt) {
         model = initialize_model(opt);
@@ -946,7 +947,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
     }

     // prepare a batch for the prompt
-    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true));
+    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), llama_data.n_past, 0, true));
     llama_token new_token_id;
     while (true) {
         check_context_size(llama_data.context, batch);
@@ -955,6 +956,8 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
             return 1;
         }

+        llama_data.n_past += llama_batch_ext_get_n_tokens(batch.get());
+
         // sample the next token, check is it an end of generation?
         new_token_id = llama_sampler_sample(llama_data.sampler.get(), llama_data.context.get(), -1);
         if (llama_vocab_is_eog(vocab, new_token_id)) {
@@ -969,7 +972,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
         print_word_and_concatenate_to_response(piece, response);

         // prepare the next batch with the sampled token
-        batch.reset(llama_batch_ext_init_from_text(&new_token_id, 1, 0, 0, true));
+        batch.reset(llama_batch_ext_init_from_text(&new_token_id, 1, llama_data.n_past, 0, true));
     }

     printf(LOG_COL_DEFAULT);
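Note: the sketch below is not part of the patch. It is a minimal, self-contained illustration of the bookkeeping the diff introduces: an accumulating n_past position carried across calls so that each new batch is placed after the tokens already in the KV cache instead of restarting at position 0. decode_tokens() here is a hypothetical stand-in for a real llama_decode() call on a context.

// Standalone sketch of the n_past bookkeeping (decode_tokens is a stub, not a llama.cpp API).
#include <cstdint>
#include <cstdio>
#include <vector>

using llama_pos = int32_t; // same idea as llama.cpp's position type

struct Session {
    llama_pos n_past = 0; // first free position in the KV cache
};

// Hypothetical decode: pretends to place `tokens` at positions [pos, pos + n).
static void decode_tokens(const std::vector<int> & tokens, llama_pos pos) {
    std::printf("decoding %zu token(s) at positions %d..%d\n",
                tokens.size(), pos, pos + (llama_pos) tokens.size() - 1);
}

// One prompt/response round: the prompt batch starts at session.n_past, then n_past advances
// by the number of tokens decoded, mirroring what the patched generate() does.
static void run_turn(Session & session, const std::vector<int> & prompt_tokens) {
    decode_tokens(prompt_tokens, session.n_past);
    session.n_past += (llama_pos) prompt_tokens.size();

    // each sampled token is then decoded one at a time at the next free position
    for (int sampled : {101, 102}) { // pretend we sampled two tokens
        decode_tokens({sampled}, session.n_past);
        session.n_past += 1;
    }
}

int main() {
    Session session;
    run_turn(session, {1, 2, 3}); // first turn starts at position 0
    run_turn(session, {4, 5});    // second turn continues after the first instead of resetting to 0
    return 0;
}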