 // use ::std::cell::LazyCell;
 use ::std::collections::HashMap;
+use std::sync::{Arc, Mutex};
 // use ::std::env::current_dir;
 use ::std::error::Error;
 use ::std::fs::File;
@@ -8,10 +9,12 @@ use ::std::fs::File;
 use async_recursion::async_recursion;
 use minijinja::{syntax::SyntaxConfig, Environment};
 use owo_colors::OwoColorize;
+use tokio::io::{stdout, AsyncWriteExt};
+use tokio_stream::StreamExt;

 use ollama_rs::{
     generation::{
-        chat::{request::ChatMessageRequest, ChatMessage, MessageRole},
+        chat::{request::ChatMessageRequest, ChatMessage, ChatMessageResponse, MessageRole},
         tools::ToolInfo,
     },
     models::ModelOptions,
@@ -288,7 +291,7 @@ impl<'a> Interpreter<'a> {
             pdl_model
                 if pdl_model.starts_with("ollama/") || pdl_model.starts_with("ollama_chat/") =>
             {
-                let mut ollama = Ollama::default();
+                let ollama = Ollama::default();
                 let model = if pdl_model.starts_with("ollama/") {
                     &pdl_model[7..]
                 } else {
@@ -313,7 +316,7 @@ impl<'a> Interpreter<'a> {
                     Some(x) => x,
                     None => (&ChatMessage::user("".into()), &[]),
                 };
-                let mut history = Vec::from(history_slice);
+                let history = Vec::from(history_slice);
                 if self.debug {
                     eprintln!(
                         "Ollama {:?} model={:?} prompt={:?} history={:?}",
@@ -327,6 +330,7 @@ impl<'a> Interpreter<'a> {
                 let req = ChatMessageRequest::new(model.into(), vec![prompt.clone()])
                     .options(options)
                     .tools(tools);
+                /* if we ever want non-streaming:
                 let res = ollama
                     .send_chat_messages_with_history(
                         &mut history,
@@ -349,6 +353,48 @@ impl<'a> Interpreter<'a> {
                 }
                 // dbg!(history);
                 Ok((vec![res.message], PdlBlock::Model(trace)))
+                */
+                let mut stream = ollama
+                    .send_chat_messages_with_history_stream(
+                        Arc::new(Mutex::new(history)),
+                        req,
+                        //ollama.generate(GenerationRequest::new(model.into(), prompt),
+                    )
+                    .await?;
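+                // Each item in the stream is a ChatMessageResponse chunk of the assistant's reply;
+                // the chat history now lives behind the Arc<Mutex<..>> handed to ollama-rs.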
+                // dbg!("Model result {:?}", &res);
+
+                let mut last_res: Option<ChatMessageResponse> = None;
+                let mut response_string = String::new();
+                let mut stdout = stdout();
+                stdout.write_all(b"\x1b[1mAssistant: \x1b[0m").await?;
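+                // Print each chunk in green as it arrives and accumulate the full response text.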
+                while let Some(Ok(res)) = stream.next().await {
+                    stdout.write_all(b"\x1b[32m").await?; // green
+                    stdout.write_all(res.message.content.as_bytes()).await?;
+                    stdout.flush().await?;
+                    stdout.write_all(b"\x1b[0m").await?; // reset color
+                    response_string += res.message.content.as_str();
+                    last_res = Some(res);
+                }
+
+                let mut trace = block.clone();
+                trace.pdl_result = Some(response_string.clone());
+
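+                // Record token usage from the final chunk and return the accumulated assistant message.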
+                if let Some(res) = last_res {
+                    if let Some(usage) = res.final_data {
+                        trace.pdl_usage = Some(PdlUsage {
+                            prompt_tokens: usage.prompt_eval_count,
+                            prompt_nanos: usage.prompt_eval_duration,
+                            completion_tokens: usage.eval_count,
+                            completion_nanos: usage.eval_duration,
+                        });
+                    }
+                    let mut message = res.message.clone();
+                    message.content = response_string;
+                    Ok((vec![message], PdlBlock::Model(trace)))
+                } else {
+                    Ok((vec![], PdlBlock::Model(trace)))
+                }
+                // dbg!(history);
             }
             _ => Err(Box::from(format!("Unsupported model {}", block.model))),
         }