@@ -86,28 +86,65 @@ void Session::setInitialPrompt(std::span<const Token> initialPrompt) {
     m_state.m_phase = State::Phase::Generating;
 }
 
-void Session::pushPrompt(std::span<const Token> prompt) {
+void Session::pushPrompt(std::span<const Token> prompt, std::span<const Token> postfix) {
     if (m_state.m_phase != State::Phase::Generating) {
         throw_ex{} << "Session hasn't started yet";
     }
 
     flushPendingState();
 
-    if (!prompt.empty()) {
-        auto& sampler = m_instance.sampler();
-        auto& model = m_instance.model();
+    if (prompt.empty() && postfix.empty()) {
+        throw_ex{} << "Prompt and postfix are empty";
+    }
+
+    auto& model = m_instance.model();
+    auto& sampler = m_instance.sampler();
+
+    // reset sampling and don't allow previous inputs to affect the generation
+    sampler.reset();
+
+    std::vector<Token> tokens;
+    constexpr uint32_t maxAdditionalTokens = 4; // bos + fim_pre + fim_suf + fim_mid
+    tokens.reserve(prompt.size() + postfix.size() + maxAdditionalTokens);
 
-        // reset sampling and don't allow previous inputs to affect the generation
-        sampler.reset();
+    if (model.prefixInputsWithBos()) {
+        const auto tokenBos = llama_vocab_bos(model.vocab().lvocab());
+        tokens.push_back(tokenBos);
+    }
 
-        if (model.prefixInputsWithBos()) {
-            const auto tokenBos = llama_vocab_bos(model.vocab().lvocab());
-            // add bos token to the prompt
-            doDecode({&tokenBos, 1}, Source::InteractivePrompt);
+    auto safeAddToken = [&](Token token, const std::string& tokenName) {
+        if (token >= 0) {
+            tokens.push_back(token);
+        } else {
+            LLAMA_LOG(Warning, "Model doesn't have a ", tokenName, " token");
         }
+    };
+
+    if (!postfix.empty()) {
+        auto tokenFIMPre = llama_vocab_fim_pre(model.vocab().lvocab());
+        safeAddToken(tokenFIMPre, "FIM Prefix");
+    }
+
+    if (!prompt.empty()) {
+        tokens.insert(tokens.end(), prompt.begin(), prompt.end());
+    }
+
+    if (!postfix.empty()) {
+        auto tokenFIMSuff = llama_vocab_fim_suf(model.vocab().lvocab());
+        safeAddToken(tokenFIMSuff, "FIM Suffix");
 
-        doDecode(prompt, Source::InteractivePrompt);
+        tokens.insert(tokens.end(), postfix.begin(), postfix.end());
+
+        auto tokenFIMMid = llama_vocab_fim_mid(model.vocab().lvocab());
+        safeAddToken(tokenFIMMid, "FIM Middle");
     }
+
+    if (tokens.size() > m_state.maxTokens) {
+        const auto ctxLen = llama_n_ctx(m_ctx);
+        throw_ex{} << "Prompt too long. Got " << tokens.size() << " tokens, max: " << ctxLen - 4;
+    }
+
+    doDecode(tokens, Source::InteractivePrompt);
 }
 
 Token Session::getToken() {
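The new two-argument pushPrompt enables fill-in-the-middle (FIM) completion: with a non-empty postfix, the input is assembled as [BOS] <fim_pre> prompt <fim_suf> postfix <fim_mid> and decoded in one call, so generation produces the tokens that belong between the two fragments. A minimal caller sketch, assuming a started Session named session and a hypothetical tokenize helper; only pushPrompt and getToken come from the code above:

    // Hypothetical usage -- `session` and `tokenize` are assumptions, not part of this diff.
    std::vector<Token> prefix = tokenize("int add(int a, int b) {\n    return ");
    std::vector<Token> suffix = tokenize(";\n}\n");

    // A non-empty postfix selects the FIM path:
    // [BOS] <fim_pre> prefix <fim_suf> suffix <fim_mid>
    session.pushPrompt(prefix, suffix);

    // Pull the generated "middle" tokens one at a time.
    for (int i = 0; i < 16; ++i) {
        Token t = session.getToken();
        // ... detokenize t, or stop on the model's end-of-generation token
    }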