@@ -261,6 +261,18 @@ impl<'a> Interpreter<'a> {
         }
     }
 
+    // TODO how can we better cope with the expected String return?
+    fn eval_string_to_string(
+        &self,
+        expr: &EvalsTo<String, String>,
+        state: &State,
+    ) -> Result<PdlResult, PdlError> {
+        match expr {
+            EvalsTo::Const(s) | EvalsTo::Jinja(s) => self.eval(s, state),
+            EvalsTo::Expr(e) => self.eval(&e.pdl_expr, state),
+        }
+    }
+
     /// Evaluate an Expr to a bool
     fn eval_to_bool(
         &self,
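The helper added above normalizes the three shapes a model reference can take: `EvalsTo::Const` and `EvalsTo::Jinja` both hand their string to `self.eval` (so Jinja templating applies to either), while `EvalsTo::Expr` unwraps to its inner `pdl_expr` first. A minimal usage sketch; the `interp` and `state` bindings and the model name are hypothetical, only `EvalsTo` and `eval_string_to_string` come from this change:

    // Hypothetical callers (not part of this commit): a literal model
    // name and a Jinja-templated one; both reduce to a PdlResult via eval().
    let constant = EvalsTo::Const("ollama/some-model".to_string());
    let templated = EvalsTo::Jinja("ollama/{{ model_name }}".to_string());
    let r1 = interp.eval_string_to_string(&constant, &state)?;
    let r2 = interp.eval_string_to_string(&templated, &state)?;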
@@ -722,140 +734,149 @@ impl<'a> Interpreter<'a> {
         })
     }
 
-    /// Run a PdlBlock::Model
-    async fn run_model(
+    async fn run_ollama_model(
         &mut self,
+        pdl_model: String,
         block: &ModelBlock,
         metadata: &Metadata,
         state: &mut State,
     ) -> BodyInterpretation {
-        match &block.model {
-            pdl_model
-                if pdl_model.starts_with("ollama/") || pdl_model.starts_with("ollama_chat/") =>
-            {
-                let mut ollama = Ollama::default();
-                let model = if pdl_model.starts_with("ollama/") {
-                    &pdl_model[7..]
-                } else {
-                    &pdl_model[12..]
-                };
+        let mut ollama = Ollama::default();
+        let model = if pdl_model.starts_with("ollama/") {
+            &pdl_model[7..]
+        } else {
+            &pdl_model[12..]
+        };
 
-                let (options, tools) = self.to_ollama_model_options(&block.parameters);
-                if self.options.debug {
-                    eprintln!("Model options {:?} {:?}", metadata.description, options);
-                    eprintln!("Model tools {:?} {:?}", metadata.description, tools);
-                }
+        let (options, tools) = self.to_ollama_model_options(&block.parameters);
+        if self.options.debug {
+            eprintln!("Model options {:?} {:?}", metadata.description, options);
+            eprintln!("Model tools {:?} {:?}", metadata.description, tools);
+        }
 
-                // The input messages to the model are either:
-                // a) block.input, if given
-                // b) the current state's accumulated messages
-                let input_messages = match &block.input {
-                    Some(input) => {
-                        // TODO ignoring result, trace
-                        let (_result, messages, _trace) = self.run_quiet(&*input, state).await?;
-                        messages
-                    }
-                    None => state.messages.clone(),
-                };
-                let (prompt, history_slice): (&ChatMessage, &[ChatMessage]) =
-                    match input_messages.split_last() {
-                        Some(x) => x,
-                        None => (&ChatMessage::user("".into()), &[]),
-                    };
-                let mut history = Vec::from(history_slice);
-                if self.options.debug {
-                    eprintln!(
-                        "Ollama {:?} model={:?} prompt={:?} history={:?}",
-                        metadata.description, block.model, prompt, history
-                    );
-                }
+        // The input messages to the model are either:
+        // a) block.input, if given
+        // b) the current state's accumulated messages
+        let input_messages = match &block.input {
+            Some(input) => {
+                // TODO ignoring result, trace
+                let (_result, messages, _trace) = self.run_quiet(&*input, state).await?;
+                messages
+            }
+            None => state.messages.clone(),
+        };
+        let (prompt, history_slice): (&ChatMessage, &[ChatMessage]) =
+            match input_messages.split_last() {
+                Some(x) => x,
+                None => (&ChatMessage::user("".into()), &[]),
+            };
+        let mut history = Vec::from(history_slice);
+        if self.options.debug {
+            eprintln!(
+                "Ollama {:?} model={:?} prompt={:?} history={:?}",
+                metadata.description, block.model, prompt, history
+            );
+        }
 
-                //if state.emit {
-                //println!("{}", pretty_print(&input_messages));
-                //}
-
-                let req = ChatMessageRequest::new(model.into(), vec![prompt.clone()])
-                    .options(options)
-                    .tools(tools);
-
-                let (last_res, response_string) = if !self.options.stream {
-                    let res = ollama
-                        .send_chat_messages_with_history(&mut history, req)
-                        .await?;
-                    let response_string = res.message.content.clone();
-                    print!("{}", response_string);
-                    (Some(res), response_string)
-                } else {
-                    let mut stream = ollama
-                        .send_chat_messages_with_history_stream(
-                            ::std::sync::Arc::new(::std::sync::Mutex::new(history)),
-                            req,
-                            //ollama.generate(GenerationRequest::new(model.into(), prompt),
-                        )
-                        .await?;
-                    // dbg!("Model result {:?}", &res);
-
-                    let emit = if let Some(_) = &block.model_response {
-                        false
-                    } else {
-                        true
-                    };
-
-                    let mut last_res: Option<ChatMessageResponse> = None;
-                    let mut response_string = String::new();
-                    let mut stdout = stdout();
-                    if emit {
-                        stdout.write_all(b"\x1b[1mAssistant: \x1b[0m").await?;
-                    }
-                    while let Some(Ok(res)) = stream.next().await {
-                        if emit {
-                            stdout.write_all(b"\x1b[32m").await?; // green
-                            stdout.write_all(res.message.content.as_bytes()).await?;
-                            stdout.flush().await?;
-                            stdout.write_all(b"\x1b[0m").await?; // reset color
-                        }
-                        response_string += res.message.content.as_str();
-                        last_res = Some(res);
-                    }
-                    if emit {
-                        stdout.write_all(b"\n").await?;
-                    }
+        //if state.emit {
+        //println!("{}", pretty_print(&input_messages));
+        //}
+
+        let req = ChatMessageRequest::new(model.into(), vec![prompt.clone()])
+            .options(options)
+            .tools(tools);
+
+        let (last_res, response_string) = if !self.options.stream {
+            let res = ollama
+                .send_chat_messages_with_history(&mut history, req)
+                .await?;
+            let response_string = res.message.content.clone();
+            print!("{}", response_string);
+            (Some(res), response_string)
+        } else {
+            let mut stream = ollama
+                .send_chat_messages_with_history_stream(
+                    ::std::sync::Arc::new(::std::sync::Mutex::new(history)),
+                    req,
+                    //ollama.generate(GenerationRequest::new(model.into(), prompt),
+                )
+                .await?;
+            // dbg!("Model result {:?}", &res);
 
-                    (last_res, response_string)
-                };
-
-                if let Some(_) = &block.model_response {
-                    if let Some(ref res) = last_res {
-                        self.def(
-                            &block.model_response,
-                            &resultify_as_litellm(&from_str(&to_string(&res)?)?),
-                            &None,
-                            state,
-                            true,
-                        )?;
-                    }
-                }
+            let emit = if let Some(_) = &block.model_response {
+                false
+            } else {
+                true
+            };
 
-                let mut trace = block.clone();
-                if let Some(res) = last_res {
-                    if let Some(usage) = res.final_data {
-                        trace.pdl_usage = Some(PdlUsage {
-                            prompt_tokens: usage.prompt_eval_count,
-                            prompt_nanos: Some(usage.prompt_eval_duration),
-                            completion_tokens: usage.eval_count,
-                            completion_nanos: Some(usage.eval_duration),
-                        });
-                    }
-                    let output_messages = vec![ChatMessage::assistant(response_string)];
-                    Ok((res.message.content.into(), output_messages, Model(trace)))
-                } else {
-                    // nothing came out of the model
-                    Ok(("".into(), vec![], Model(trace)))
+            let mut last_res: Option<ChatMessageResponse> = None;
+            let mut response_string = String::new();
+            let mut stdout = stdout();
+            if emit {
+                stdout.write_all(b"\x1b[1mAssistant: \x1b[0m").await?;
+            }
+            while let Some(Ok(res)) = stream.next().await {
+                if emit {
+                    stdout.write_all(b"\x1b[32m").await?; // green
+                    stdout.write_all(res.message.content.as_bytes()).await?;
+                    stdout.flush().await?;
+                    stdout.write_all(b"\x1b[0m").await?; // reset color
                 }
-                // dbg!(history);
+                response_string += res.message.content.as_str();
+                last_res = Some(res);
+            }
+            if emit {
+                stdout.write_all(b"\n").await?;
+            }
+
+            (last_res, response_string)
+        };
+
+        if let Some(_) = &block.model_response {
+            if let Some(ref res) = last_res {
+                self.def(
+                    &block.model_response,
+                    &resultify_as_litellm(&from_str(&to_string(&res)?)?),
+                    &None,
+                    state,
+                    true,
+                )?;
            }
-            _ => Err(Box::from(format!("Unsupported model {}", block.model))),
         }
+
+        let mut trace = block.clone();
+        if let Some(res) = last_res {
+            if let Some(usage) = res.final_data {
+                trace.pdl_usage = Some(PdlUsage {
+                    prompt_tokens: usage.prompt_eval_count,
+                    prompt_nanos: Some(usage.prompt_eval_duration),
+                    completion_tokens: usage.eval_count,
+                    completion_nanos: Some(usage.eval_duration),
+                });
+            }
+            let output_messages = vec![ChatMessage::assistant(response_string)];
+            Ok((res.message.content.into(), output_messages, Model(trace)))
+        } else {
+            // nothing came out of the model
+            Ok(("".into(), vec![], Model(trace)))
+        }
+        // dbg!(history);
+    }
+
+    /// Run a PdlBlock::Model
+    async fn run_model(
+        &mut self,
+        block: &ModelBlock,
+        metadata: &Metadata,
+        state: &mut State,
+    ) -> BodyInterpretation {
+        if let PdlResult::String(s) = self.eval_string_to_string(&block.model, state)? {
+            if s.starts_with("ollama/") || s.starts_with("ollama_chat/") {
+                return self.run_ollama_model(s, block, metadata, state).await;
+            }
+        }
+
+        Err(Box::from(format!("Unsupported model {:?}", block.model)))
     }
 
     /// Run a PdlBlock::Data
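A side note on the prefix handling in `run_ollama_model`: the hard-coded `7` and `12` are the lengths of `"ollama/"` and `"ollama_chat/"`. A sketch of an equivalent, offset-free formulation using `str::strip_prefix`; this is an editorial suggestion, not code from the commit:

    // Sketch only: same prefix handling without magic slice offsets.
    // strip_prefix yields None when the prefix is absent, so the chain
    // falls through to the next candidate, then to the input unchanged.
    fn ollama_model_name(pdl_model: &str) -> &str {
        pdl_model
            .strip_prefix("ollama/")
            .or_else(|| pdl_model.strip_prefix("ollama_chat/"))
            .unwrap_or(pdl_model)
    }

Net effect of the refactor: `run_model` now evaluates the model field first (so Jinja-templated model names resolve), dispatches `ollama/`- and `ollama_chat/`-prefixed results to `run_ollama_model`, and reports anything else as an unsupported model.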