@@ -83,6 +83,102 @@ static void sigint_handler(int signo) {
8383}
8484#endif
8585
86+ class partial_formatter {
87+ public:
88+ enum output_type {
89+ CONTENT,
90+ REASONING,
91+ };
92+
93+ struct output {
94+ std::string formatted;
95+ output_type type;
96+ };
97+
98+ partial_formatter (const common_chat_syntax & syntax) : syntax_(syntax), had_reasoning_(false ) {}
99+
100+ std::vector<output> operator ()(const std::string & accumulated) {
101+ common_chat_msg next = common_chat_parse (accumulated, true , syntax_);
102+
103+ auto diffs = common_chat_msg_diff::compute_diffs (previous_, next);
104+ std::vector<output> result;
105+ for (const auto & diff : diffs) {
106+ if (!diff.reasoning_content_delta .empty ()) {
107+ result.push_back ({diff.reasoning_content_delta , REASONING});
108+ had_reasoning_ = true ;
109+ }
110+ if (!diff.content_delta .empty ()) {
111+ if (had_reasoning_) {
112+ result.push_back ({" \n " , REASONING});
113+ had_reasoning_ = false ;
114+ }
115+ result.push_back ({diff.content_delta , CONTENT});
116+ }
117+ }
118+ previous_ = next;
119+ return result;
120+ }
121+
122+ private:
123+ common_chat_syntax syntax_;
124+ common_chat_msg previous_;
125+ bool had_reasoning_;
126+ };
127+
128+ class chat_formatter {
129+ public:
130+ chat_formatter (
131+ std::vector<common_chat_msg> & chat_msgs,
132+ const common_chat_templates_ptr & chat_templates,
133+ const common_params & params)
134+ : chat_msgs_(chat_msgs),
135+ chat_templates_ (chat_templates),
136+ params_(params) {}
137+
138+ std::string operator ()(const std::string & role, const std::string & content) {
139+ common_chat_msg new_msg;
140+ new_msg.role = role;
141+ new_msg.content = content;
142+ chat_msgs_.push_back (new_msg);
143+
144+ common_chat_templates_inputs cinputs;
145+ cinputs.use_jinja = params_.use_jinja ;
146+ cinputs.messages = chat_msgs_;
147+ cinputs.add_generation_prompt = (role == " user" );
148+ cinputs.reasoning_format = params_.reasoning_format ;
149+
150+ cinputs.enable_thinking =
151+ params_.use_jinja && params_.reasoning_budget != 0 &&
152+ common_chat_templates_support_enable_thinking (chat_templates_.get ());
153+
154+ common_chat_params cparams = common_chat_templates_apply (chat_templates_.get (), cinputs);
155+
156+ if (!partial_formatter_ptr_ && params_.reasoning_format != COMMON_REASONING_FORMAT_NONE) {
157+ common_chat_syntax chat_syntax;
158+ chat_syntax.format = cparams.format ;
159+ chat_syntax.reasoning_format = params_.reasoning_format ;
160+ chat_syntax.thinking_forced_open = cparams.thinking_forced_open ;
161+ chat_syntax.parse_tool_calls = false ;
162+ partial_formatter_ptr_ = std::make_unique<partial_formatter>(chat_syntax);
163+ }
164+
165+ std::string formatted = cparams.prompt .substr (formatted_cumulative_.size ());
166+ formatted_cumulative_ = cparams.prompt ;
167+
168+ LOG_DBG (" formatted: '%s'\n " , formatted.c_str ());
169+ return formatted;
170+ }
171+
172+ partial_formatter * get_partial_formatter () { return partial_formatter_ptr_.get (); }
173+
174+ private:
175+ std::vector<common_chat_msg> & chat_msgs_;
176+ const common_chat_templates_ptr & chat_templates_;
177+ const common_params & params_;
178+ std::unique_ptr<partial_formatter> partial_formatter_ptr_;
179+ std::string formatted_cumulative_;
180+ };
181+
86182int main (int argc, char ** argv) {
87183 common_params params;
88184 g_params = ¶ms;
@@ -265,15 +361,7 @@ int main(int argc, char ** argv) {
265361 std::vector<llama_token> embd_inp;
266362
267363 bool waiting_for_first_input = false ;
268- auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) {
269- common_chat_msg new_msg;
270- new_msg.role = role;
271- new_msg.content = content;
272- auto formatted = common_chat_format_single (chat_templates.get (), chat_msgs, new_msg, role == " user" , g_params->use_jinja );
273- chat_msgs.push_back (new_msg);
274- LOG_DBG (" formatted: '%s'\n " , formatted.c_str ());
275- return formatted;
276- };
364+ chat_formatter chat_add_and_format (chat_msgs, chat_templates, params);
277365
278366 std::string prompt;
279367 {
@@ -709,6 +797,13 @@ int main(int argc, char ** argv) {
709797
710798 if (params.conversation_mode && !waiting_for_first_input && !llama_vocab_is_eog (vocab, id)) {
711799 assistant_ss << common_token_to_piece (ctx, id, false );
800+
801+ if (auto * formatter = chat_add_and_format.get_partial_formatter ()) {
802+ auto outputs = (*formatter)(assistant_ss.str ());
803+ for (const auto & out : outputs) {
804+ LOG (" %s" , out.formatted .c_str ());
805+ }
806+ }
712807 }
713808
714809 // echo this to console
@@ -740,8 +835,9 @@ int main(int argc, char ** argv) {
740835 for (auto id : embd) {
741836 const std::string token_str = common_token_to_piece (ctx, id, params.special );
742837
743- // Console/Stream Output
744- LOG (" %s" , token_str.c_str ());
838+ if (!chat_add_and_format.get_partial_formatter () || assistant_ss.str ().empty ()) {
839+ LOG (" %s" , token_str.c_str ());
840+ }
745841
746842 // Record Displayed Tokens To Log
747843 // Note: Generated tokens are created one by one hence this check
0 commit comments