@@ -90,41 +90,33 @@ void LLModel::prompt(const std::string &prompt,
         }
     }
 
-    auto old_n_past = promptCtx.n_past; // prepare to fake n_past for tokenize
-
     // tokenize the user prompt
     std::vector<Token> embd_inp;
     if (placeholders.empty()) {
         // this is unusual, but well-defined
         std::cerr << __func__ << ": prompt template has no placeholder\n";
-        embd_inp = tokenize(promptCtx, promptTemplate, true);
+        embd_inp = tokenize(promptTemplate, true);
     } else {
         // template: beginning of user prompt
         const auto &phUser = placeholders[0];
         std::string userPrefix(phUser.prefix());
-        if (!userPrefix.empty()) {
-            embd_inp = tokenize(promptCtx, userPrefix, true);
-            promptCtx.n_past += embd_inp.size();
-        }
+        if (!userPrefix.empty())
+            embd_inp = tokenize(userPrefix, true);
 
         // user input (shouldn't have special token processing)
-        auto tokens = tokenize(promptCtx, prompt, special);
+        auto tokens = tokenize(prompt, special);
         embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
-        promptCtx.n_past += tokens.size();
 
         // template: end of user prompt + start of assistant prompt
         size_t start = phUser.position() + phUser.length();
         size_t end = placeholders.size() >= 2 ? placeholders[1].position() : promptTemplate.length();
         auto userToAsst = promptTemplate.substr(start, end - start);
         if (!userToAsst.empty()) {
-            tokens = tokenize(promptCtx, userToAsst, true);
+            tokens = tokenize(userToAsst, true);
             embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());
-            promptCtx.n_past += tokens.size();
         }
     }
 
-    promptCtx.n_past = old_n_past; // restore n_past so decodePrompt can increment it
-
     // decode the user prompt
     if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp))
         return; // error
@@ -133,7 +125,7 @@ void LLModel::prompt(const std::string &prompt,
     if (!fakeReply) {
         generateResponse(responseCallback, allowContextShift, promptCtx);
     } else {
-        embd_inp = tokenize(promptCtx, *fakeReply, false);
+        embd_inp = tokenize(*fakeReply, false);
         if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp, true))
             return; // error
     }
@@ -148,7 +140,7 @@ void LLModel::prompt(const std::string &prompt,
         asstSuffix = "\n\n"; // default to a blank line, good for e.g. Alpaca
     }
     if (!asstSuffix.empty()) {
-        embd_inp = tokenize(promptCtx, asstSuffix, true);
+        embd_inp = tokenize(asstSuffix, true);
         decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp);
     }
 }
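
Taken together, the hunks above turn tokenize() into a pure string-to-tokens call and leave all n_past accounting to decodePrompt(), which is why the old_n_past save/restore dance disappears. The sketch below illustrates that split in isolation. It compiles standalone, but Token, PromptContext, tokenize(), and decodePrompt() are simplified stand-ins for the real LLModel members, not the project's actual API.

// Sketch of the refactored flow, under stated assumptions: tokenize() takes
// no PromptContext, and n_past is advanced in exactly one place, the decoder.
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

using Token = int32_t;

struct PromptContext {
    int32_t n_past = 0; // tokens the model has already evaluated
};

// Stand-in tokenizer: one fake token per character (illustrative only).
static std::vector<Token> tokenize(const std::string &text, bool /*special*/)
{
    return std::vector<Token>(text.begin(), text.end());
}

// Stand-in decoder: the sole owner of n_past bookkeeping.
static bool decodePrompt(PromptContext &ctx, const std::vector<Token> &tokens)
{
    ctx.n_past += static_cast<int32_t>(tokens.size());
    return true;
}

int main()
{
    PromptContext ctx;

    // Build the full input by concatenating independently tokenized segments,
    // as the patched prompt() does; no context is threaded through tokenize().
    std::vector<Token> embd_inp = tokenize("### Instruction:\n", true);
    auto tokens = tokenize("hello", false);
    embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end());

    decodePrompt(ctx, embd_inp);
    std::cout << "n_past = " << ctx.n_past << '\n'; // advanced once, by the decoder
    return 0;
}

With the increment living in one function, the token count can no longer drift between the tokenize-time and decode-time paths, which is the bookkeeping hazard the old save/restore pattern had to guard against.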