@@ -15,13 +15,13 @@ use tokio_stream::StreamExt;
1515use ollama_rs:: {
1616 generation:: {
1717 chat:: { request:: ChatMessageRequest , ChatMessage , ChatMessageResponse , MessageRole } ,
18- tools:: ToolInfo ,
18+ tools:: { ToolFunctionInfo , ToolInfo , ToolType } ,
1919 } ,
2020 models:: ModelOptions ,
2121 Ollama ,
2222} ;
2323
24- use serde_json:: { from_str, to_string, Value } ;
24+ use serde_json:: { from_str, json , to_string, Value } ;
2525use serde_norway:: { from_reader, from_str as from_yaml_str} ;
2626
2727use crate :: pdl:: ast:: {
@@ -45,6 +45,7 @@ struct Interpreter<'a> {
4545 scope : Vec < Scope > ,
4646 debug : bool ,
4747 emit : bool ,
48+ stream : bool ,
4849}
4950
5051impl < ' a > Interpreter < ' a > {
@@ -67,6 +68,7 @@ impl<'a> Interpreter<'a> {
6768 scope : vec ! [ Scope :: new( ) ] ,
6869 debug : false ,
6970 emit : true ,
71+ stream : true ,
7072 }
7173 }
7274
@@ -76,14 +78,13 @@ impl<'a> Interpreter<'a> {
7678 context : Context ,
7779 emit : bool ,
7880 ) -> Interpretation {
79- if self . debug {
81+ /* if self.debug {
8082 if let Some(scope) = self.scope.last() {
8183 if scope.len() > 0 {
8284 eprintln!("Run with Scope {:?}", scope);
8385 }
8486 }
85- }
86-
87+ } */
8788 let prior_emit = self . emit ;
8889 self . emit = emit;
8990
@@ -118,7 +119,9 @@ impl<'a> Interpreter<'a> {
118119 } ?;
119120
120121 if match program {
121- PdlBlock :: Call ( _) | PdlBlock :: Model ( _) => false ,
122+ PdlBlock :: Text ( _) | PdlBlock :: LastOf ( _) | PdlBlock :: Call ( _) | PdlBlock :: Model ( _) => {
123+ false
124+ }
122125 _ => self . emit ,
123126 } {
124127 println ! ( "{}" , pretty_print( & messages) ) ;
@@ -150,7 +153,7 @@ impl<'a> Interpreter<'a> {
150153 let backup = result. clone ( ) ;
151154 Ok ( from_str ( & result) . unwrap_or_else ( |err| {
152155 if self . debug {
153- eprintln ! ( "Treating as plain string {}" , & result) ;
156+ eprintln ! ( "Treating as plain string {}" , result) ;
154157 eprintln ! ( "... due to {}" , err) ;
155158 }
156159 backup. into ( )
@@ -332,7 +335,10 @@ impl<'a> Interpreter<'a> {
332335
333336 self . run ( & c. function . return_ , context. clone ( ) ) . await
334337 }
335- _ => Err ( Box :: from ( format ! ( "call of non-function {:?}" , & block. call) ) ) ,
338+ x => Err ( Box :: from ( format ! (
339+ "call of non-function {:?}->{:?}" ,
340+ block. call, x
341+ ) ) ) ,
336342 } ;
337343
338344 if let Some ( _) = block. args {
@@ -438,10 +444,36 @@ impl<'a> Interpreter<'a> {
438444 0.0
439445 } ;
440446
441- let tools = if let Some ( Value :: Array ( _tools) ) = parameters. get ( & "tools" . to_string ( ) ) {
442- // TODO
443- //tools.into_iter().map(|tool| function!()).collect()
444- vec ! [ ]
447+ let tools = if let Some ( Value :: Array ( tools) ) = parameters. get ( "tools" ) {
448+ tools
449+ . into_iter ( )
450+ . filter_map ( |tool| tool. get ( "function" ) )
451+ . filter_map ( |tool| {
452+ //from_str(&to_string(tool)?)
453+ match (
454+ tool. get ( "name" ) ,
455+ tool. get ( "description" ) ,
456+ tool. get ( "parameters" ) ,
457+ ) {
458+ (
459+ Some ( Value :: String ( name) ) ,
460+ Some ( Value :: String ( description) ) ,
461+ Some ( Value :: Object ( parameters) ) ,
462+ ) => Some ( ToolInfo {
463+ tool_type : ToolType :: Function ,
464+ function : ToolFunctionInfo {
465+ name : name. to_string ( ) ,
466+ description : description. to_string ( ) ,
467+ parameters : schemars:: schema_for_value!( parameters) ,
468+ } ,
469+ } ) ,
470+ _ => {
471+ eprintln ! ( "Error: tools do not satisfy schema {:?}" , tool) ;
472+ None
473+ }
474+ }
475+ } )
476+ . collect ( )
445477 } else {
446478 vec ! [ ]
447479 } ;
@@ -517,7 +549,7 @@ impl<'a> Interpreter<'a> {
517549 pdl_model
518550 if pdl_model. starts_with ( "ollama/" ) || pdl_model. starts_with ( "ollama_chat/" ) =>
519551 {
520- let ollama = Ollama :: default ( ) ;
552+ let mut ollama = Ollama :: default ( ) ;
521553 let model = if pdl_model. starts_with ( "ollama/" ) {
522554 & pdl_model[ 7 ..]
523555 } else {
@@ -526,7 +558,8 @@ impl<'a> Interpreter<'a> {
526558
527559 let ( options, tools) = self . to_ollama_model_options ( & block. parameters ) ;
528560 if self . debug {
529- println ! ( "Model options {:?}" , options) ;
561+ eprintln ! ( "Model options {:?} {:?}" , block. description, options) ;
562+ eprintln ! ( "Model tools {:?} {:?}" , block. description, tools) ;
530563 }
531564
532565 let input_messages = match & block. input {
@@ -542,7 +575,7 @@ impl<'a> Interpreter<'a> {
542575 Some ( x) => x,
543576 None => ( & ChatMessage :: user ( "" . into ( ) ) , & [ ] ) ,
544577 } ;
545- let history = Vec :: from ( history_slice) ;
578+ let mut history = Vec :: from ( history_slice) ;
546579 if self . debug {
547580 eprintln ! (
548581 "Ollama {:?} model={:?} prompt={:?} history={:?}" ,
@@ -560,52 +593,62 @@ impl<'a> Interpreter<'a> {
560593 let req = ChatMessageRequest :: new ( model. into ( ) , vec ! [ prompt. clone( ) ] )
561594 . options ( options)
562595 . tools ( tools) ;
563- /* if we ever want non-streaming:
564- let res = ollama
565- .send_chat_messages_with_history(
566- &mut history,
567- req,
568- //ollama.generate(GenerationRequest::new(model.into(), prompt),
569- )
570- .await?;
571- // dbg!("Model result {:?}", &res);
572596
573- let mut trace = block.clone();
574- trace.pdl_result = Some(res.message.content.clone());
575-
576- if let Some(usage) = res.final_data {
577- trace.pdl_usage = Some(PdlUsage {
578- prompt_tokens: usage.prompt_eval_count,
579- prompt_nanos: usage.prompt_eval_duration,
580- completion_tokens: usage.eval_count,
581- completion_nanos: usage.eval_duration,
582- });
583- }
584- // dbg!(history);
585- Ok((vec![res.message], PdlBlock::Model(trace)))
586- */
587- let mut stream = ollama
588- . send_chat_messages_with_history_stream (
589- Arc :: new ( Mutex :: new ( history) ) ,
590- req,
591- //ollama.generate(GenerationRequest::new(model.into(), prompt),
592- )
593- . await ?;
594- // dbg!("Model result {:?}", &res);
595-
596- let mut last_res: Option < ChatMessageResponse > = None ;
597- let mut response_string = String :: new ( ) ;
598- let mut stdout = stdout ( ) ;
599- stdout. write_all ( b"\x1b [1mAssistant: \x1b [0m" ) . await ?;
600- while let Some ( Ok ( res) ) = stream. next ( ) . await {
601- stdout. write_all ( b"\x1b [32m" ) . await ?; // green
602- stdout. write_all ( res. message . content . as_bytes ( ) ) . await ?;
603- stdout. flush ( ) . await ?;
604- stdout. write_all ( b"\x1b [0m" ) . await ?; // reset color
605- response_string += res. message . content . as_str ( ) ;
606- last_res = Some ( res) ;
597+ let ( last_res, response_string) = if !self . stream {
598+ let res = ollama
599+ . send_chat_messages_with_history ( & mut history, req)
600+ . await ?;
601+ let response_string = res. message . content . clone ( ) ;
602+ print ! ( "{}" , response_string) ;
603+ ( Some ( res) , response_string)
604+ } else {
605+ let mut stream = ollama
606+ . send_chat_messages_with_history_stream (
607+ Arc :: new ( Mutex :: new ( history) ) ,
608+ req,
609+ //ollama.generate(GenerationRequest::new(model.into(), prompt),
610+ )
611+ . await ?;
612+ // dbg!("Model result {:?}", &res);
613+
614+ let emit = if let Some ( _) = & block. model_response {
615+ false
616+ } else {
617+ true
618+ } ;
619+
620+ let mut last_res: Option < ChatMessageResponse > = None ;
621+ let mut response_string = String :: new ( ) ;
622+ let mut stdout = stdout ( ) ;
623+ if emit {
624+ stdout. write_all ( b"\x1b [1mAssistant: \x1b [0m" ) . await ?;
625+ }
626+ while let Some ( Ok ( res) ) = stream. next ( ) . await {
627+ if emit {
628+ stdout. write_all ( b"\x1b [32m" ) . await ?; // green
629+ stdout. write_all ( res. message . content . as_bytes ( ) ) . await ?;
630+ stdout. flush ( ) . await ?;
631+ stdout. write_all ( b"\x1b [0m" ) . await ?; // reset color
632+ }
633+ response_string += res. message . content . as_str ( ) ;
634+ last_res = Some ( res) ;
635+ }
636+ if emit {
637+ stdout. write_all ( b"\n " ) . await ?;
638+ }
639+
640+ ( last_res, response_string)
641+ } ;
642+
643+ if let Some ( _) = & block. model_response {
644+ if let Some ( ref res) = last_res {
645+ self . def (
646+ & block. model_response ,
647+ & self . resultify_as_litellm ( & from_str ( & to_string ( & res) ?) ?) ,
648+ & None ,
649+ ) ?;
650+ }
607651 }
608- stdout. write_all ( b"\n " ) . await ?;
609652
610653 let mut trace = block. clone ( ) ;
611654 trace. pdl_result = Some ( response_string. clone ( ) ) ;
@@ -653,6 +696,15 @@ impl<'a> Interpreter<'a> {
653696 }
654697 }
655698
699+ /// Transform a JSON Value into a PdlResult object that is compatible with litellm's model response schema
700+ fn resultify_as_litellm ( & self , value : & Value ) -> PdlResult {
701+ self . resultify ( & json ! ( {
702+ "choices" : [
703+ value
704+ ]
705+ } ) )
706+ }
707+
656708 /// Run a PdlBlock::Data
657709 async fn run_data ( & mut self , block : & DataBlock , _context : Context ) -> Interpretation {
658710 if self . debug {
@@ -821,7 +873,7 @@ impl<'a> Interpreter<'a> {
821873 while let Some ( block) = iter. next ( ) {
822874 // run each element of the Text block
823875 let ( this_result, this_messages, trace) =
824- self . run_quiet ( & block, input_messages. clone ( ) ) . await ?;
876+ self . run ( & block, input_messages. clone ( ) ) . await ?;
825877 input_messages. extend ( this_messages. clone ( ) ) ;
826878 output_results. push ( this_result) ;
827879
@@ -918,17 +970,23 @@ impl<'a> Interpreter<'a> {
918970 }
919971}
920972
921- pub async fn run ( program : & PdlBlock , cwd : Option < PathBuf > , debug : bool ) -> Interpretation {
973+ pub async fn run (
974+ program : & PdlBlock ,
975+ cwd : Option < PathBuf > ,
976+ debug : bool ,
977+ stream : bool ,
978+ ) -> Interpretation {
922979 let mut interpreter = Interpreter :: new ( ) ;
923980 interpreter. debug = debug;
981+ interpreter. stream = stream;
924982 if let Some ( cwd) = cwd {
925983 interpreter. cwd = cwd
926984 } ;
927985 interpreter. run ( & program, vec ! [ ] ) . await
928986}
929987
930988pub fn run_sync ( program : & PdlBlock , cwd : Option < PathBuf > , debug : bool ) -> InterpretationSync {
931- tauri:: async_runtime:: block_on ( run ( program, cwd, debug) )
989+ tauri:: async_runtime:: block_on ( run ( program, cwd, debug, true ) )
932990 . map_err ( |err| Box :: < dyn Error > :: from ( err. to_string ( ) ) )
933991}
934992
@@ -938,22 +996,22 @@ pub fn parse_file(path: &PathBuf) -> Result<PdlBlock, PdlError> {
938996 . map_err ( |err| Box :: < dyn Error + Send + Sync > :: from ( err. to_string ( ) ) )
939997}
940998
941- pub async fn run_file ( source_file_path : & str , debug : bool ) -> Interpretation {
999+ pub async fn run_file ( source_file_path : & str , debug : bool , stream : bool ) -> Interpretation {
9421000 let path = PathBuf :: from ( source_file_path) ;
9431001 let cwd = path. parent ( ) . and_then ( |cwd| Some ( cwd. to_path_buf ( ) ) ) ;
9441002 let program = parse_file ( & path) ?;
9451003
9461004 crate :: pdl:: pull:: pull_if_needed ( & program) . await ?;
947- run ( & program, cwd, debug) . await
1005+ run ( & program, cwd, debug, stream ) . await
9481006}
9491007
950- pub fn run_file_sync ( source_file_path : & str , debug : bool ) -> InterpretationSync {
951- tauri:: async_runtime:: block_on ( run_file ( source_file_path, debug) )
1008+ pub fn run_file_sync ( source_file_path : & str , debug : bool , stream : bool ) -> InterpretationSync {
1009+ tauri:: async_runtime:: block_on ( run_file ( source_file_path, debug, stream ) )
9521010 . map_err ( |err| Box :: < dyn Error > :: from ( err. to_string ( ) ) )
9531011}
9541012
9551013pub async fn run_string ( source : & str , debug : bool ) -> Interpretation {
956- run ( & from_yaml_str ( source) ?, None , debug) . await
1014+ run ( & from_yaml_str ( source) ?, None , debug, true ) . await
9571015}
9581016
9591017pub async fn run_json ( source : Value , debug : bool ) -> Interpretation {
0 commit comments