@@ -26,8 +26,9 @@ use serde_json::{from_str, to_string, Number, Value};
2626use serde_norway:: { from_reader, from_str as from_yaml_str} ;
2727
2828use crate :: pdl:: ast:: {
29- CallBlock , FunctionBlock , IfBlock , ListOrString , ModelBlock , ObjectBlock , PdlBlock , PdlParser ,
30- PdlUsage , PythonCodeBlock , ReadBlock , RepeatBlock , Role , StringOrBoolean , TextBlock ,
29+ ArrayBlock , CallBlock , FunctionBlock , IfBlock , ListOrString , MessageBlock , ModelBlock ,
30+ ObjectBlock , PdlBlock , PdlParser , PdlUsage , PythonCodeBlock , ReadBlock , RepeatBlock , Role ,
31+ StringOrBoolean , TextBlock ,
3132} ;
3233
3334#[ derive( Serialize , Deserialize , Debug , Clone ) ]
@@ -48,23 +49,11 @@ pub enum PdlResult {
4849 Dict ( HashMap < String , PdlResult > ) ,
4950}
5051impl :: std:: fmt:: Display for PdlResult {
51- // This trait requires `fmt` with this exact signature.
5252 fn fmt ( & self , f : & mut :: std:: fmt:: Formatter ) -> :: std:: fmt:: Result {
5353 let s = to_string ( & self ) . unwrap ( ) ; // TODO: .map_err(|e| e.to_string())?;
5454 write ! ( f, "{}" , s)
5555 }
5656}
57- /*impl From<&Value> for PdlResult {
58- fn from(v: &Value) -> Self {
59- match v {
60- Value::Bool(b) => PdlResult::Bool(*b),
61- Value::String(s) => PdlResult::String(s.clone()),
62- Value::Number(n) => PdlResult::Number(n.clone()),
63- Value::Array(a) => PdlResult::List(a.into_iter().map(|v| v.into())),
64- Value::Object(m) => PdlResult::Dict(m.into_iter().map(|(k,v)| (k, v.into()))),
65- }
66- }
67- }*/
6857impl From < & str > for PdlResult {
6958 fn from ( s : & str ) -> Self {
7059 PdlResult :: String ( s. to_string ( ) )
@@ -93,7 +82,6 @@ struct Interpreter<'a> {
9382 cwd : PathBuf ,
9483 // id_stack: Vec<String>,
9584 jinja_env : Environment < ' a > ,
96- // rt: Runtime,
9785 scope : Vec < Scope > ,
9886 debug : bool ,
9987 emit : bool ,
@@ -116,7 +104,6 @@ impl<'a> Interpreter<'a> {
116104 cwd : current_dir ( ) . unwrap_or ( PathBuf :: from ( "/" ) ) ,
117105 // id_stack: vec![],
118106 jinja_env : jinja_env,
119- // rt: Runtime::new().unwrap(),
120107 scope : vec ! [ Scope :: new( ) ] ,
121108 debug : false ,
122109 emit : true ,
@@ -160,14 +147,15 @@ impl<'a> Interpreter<'a> {
160147 PdlBlock :: Read ( block) => self . run_read ( block, context) . await ,
161148 PdlBlock :: Repeat ( block) => self . run_repeat ( block, context) . await ,
162149 PdlBlock :: Text ( block) => self . run_text ( block, context) . await ,
150+ PdlBlock :: Array ( block) => self . run_array ( block, context) . await ,
151+ PdlBlock :: Message ( block) => self . run_message ( block, context) . await ,
163152 _ => Err ( Box :: from ( format ! ( "Unsupported block {:?}" , program) ) ) ,
164153 } ?;
165154
166155 if match program {
167- PdlBlock :: Call ( _) | PdlBlock :: Text ( _) => false ,
156+ PdlBlock :: Call ( _) | PdlBlock :: Model ( _ ) | PdlBlock :: Text ( _) => false ,
168157 _ => self . emit ,
169158 } {
170- // eprintln!("{:?}", program);
171159 println ! ( "{}" , pretty_print( & messages) ) ;
172160 }
173161 self . emit = prior_emit;
@@ -204,6 +192,17 @@ impl<'a> Interpreter<'a> {
204192 } ) )
205193 }
206194
195+ /// Evaluate String as a Jinja2 expression, expecting a string in response
196+ fn eval_to_string ( & self , expr : & String ) -> Result < String , PdlError > {
197+ match self . eval ( expr) ? {
198+ PdlResult :: String ( s) => Ok ( s) ,
199+ x => Err ( Box :: from ( format ! (
200+ "Expression {expr} evaluated to non-string {:?}" ,
201+ x
202+ ) ) ) ,
203+ }
204+ }
205+
207206 fn eval_complex ( & self , expr : & Value ) -> Result < PdlResult , PdlError > {
208207 match expr {
209208 Value :: Null => Ok ( "" . into ( ) ) ,
@@ -499,7 +498,7 @@ impl<'a> Interpreter<'a> {
499498 println ! ( "Model options {:?}" , options) ;
500499 }
501500
502- let messages = match & block. input {
501+ let input_messages = match & block. input {
503502 Some ( input) => {
504503 // TODO ignoring result, trace
505504 let ( _result, messages, _trace) = self . run_quiet ( & * input, context) . await ?;
@@ -508,7 +507,7 @@ impl<'a> Interpreter<'a> {
508507 None => context,
509508 } ;
510509 let ( prompt, history_slice) : ( & ChatMessage , & [ ChatMessage ] ) =
511- match messages . split_last ( ) {
510+ match input_messages . split_last ( ) {
512511 Some ( x) => x,
513512 None => ( & ChatMessage :: user ( "" . into ( ) ) , & [ ] ) ,
514513 } ;
@@ -523,6 +522,10 @@ impl<'a> Interpreter<'a> {
523522 ) ;
524523 }
525524
525+ if self . emit {
526+ println ! ( "{}" , pretty_print( & input_messages) ) ;
527+ }
528+
526529 let req = ChatMessageRequest :: new ( model. into ( ) , vec ! [ prompt. clone( ) ] )
527530 . options ( options)
528531 . tools ( tools) ;
@@ -571,6 +574,7 @@ impl<'a> Interpreter<'a> {
571574 response_string += res. message . content . as_str ( ) ;
572575 last_res = Some ( res) ;
573576 }
577+ stdout. write_all ( b"\n " ) . await ?;
574578
575579 let mut trace = block. clone ( ) ;
576580 trace. pdl_result = Some ( response_string. clone ( ) ) ;
@@ -584,11 +588,10 @@ impl<'a> Interpreter<'a> {
584588 completion_nanos : usage. eval_duration ,
585589 } ) ;
586590 }
587- let mut message = res. message . clone ( ) ;
588- message. content = response_string;
591+ let output_messages = vec ! [ ChatMessage :: assistant( response_string) ] ;
589592 Ok ( (
590- message. content . clone ( ) . into ( ) ,
591- vec ! [ message ] ,
593+ res . message . content . into ( ) ,
594+ output_messages ,
592595 PdlBlock :: Model ( trace) ,
593596 ) )
594597 } else {
@@ -673,7 +676,7 @@ impl<'a> Interpreter<'a> {
673676 match role {
674677 Role :: User => MessageRole :: User ,
675678 Role :: Assistant => MessageRole :: Assistant ,
676- Role :: System => MessageRole :: Assistant ,
679+ Role :: System => MessageRole :: System ,
677680 Role :: Tool => MessageRole :: Tool ,
678681 }
679682 }
@@ -786,6 +789,70 @@ impl<'a> Interpreter<'a> {
786789 PdlBlock :: Text ( trace) ,
787790 ) )
788791 }
792+
793+ async fn run_array ( & mut self , block : & ArrayBlock , context : Context ) -> Interpretation {
794+ let mut result_items = vec ! [ ] ;
795+ let mut all_messages = vec ! [ ] ;
796+ let mut trace_items = vec ! [ ] ;
797+
798+ let mut iter = block. array . iter ( ) ;
799+ while let Some ( item) = iter. next ( ) {
800+ // TODO accumulate messages
801+ let ( result, messages, trace) = self . run_quiet ( item, context. clone ( ) ) . await ?;
802+ result_items. push ( result) ;
803+ all_messages. extend ( messages) ;
804+ trace_items. push ( trace) ;
805+ }
806+
807+ Ok ( (
808+ PdlResult :: List ( result_items) ,
809+ all_messages,
810+ PdlBlock :: Array ( ArrayBlock { array : trace_items } ) ,
811+ ) )
812+ }
813+
814+ async fn run_message ( & mut self , block : & MessageBlock , context : Context ) -> Interpretation {
815+ let ( content_result, content_messages, content_trace) =
816+ self . run ( & block. content , context) . await ?;
817+ let name = if let Some ( name) = & block. name {
818+ Some ( self . eval_to_string ( & name) ?)
819+ } else {
820+ None
821+ } ;
822+ let tool_call_id = if let Some ( tool_call_id) = & block. tool_call_id {
823+ Some ( self . eval_to_string ( & tool_call_id) ?)
824+ } else {
825+ None
826+ } ;
827+
828+ let mut dict: HashMap < String , PdlResult > = HashMap :: new ( ) ;
829+ dict. insert ( "role" . into ( ) , PdlResult :: String ( to_string ( & block. role ) ?) ) ;
830+ dict. insert ( "content" . into ( ) , content_result) ;
831+ if let Some ( name) = & name {
832+ dict. insert ( "name" . into ( ) , PdlResult :: String ( name. clone ( ) ) ) ;
833+ }
834+ if let Some ( tool_call_id) = & tool_call_id {
835+ dict. insert (
836+ "tool_call_id" . into ( ) ,
837+ PdlResult :: String ( tool_call_id. clone ( ) ) ,
838+ ) ;
839+ }
840+
841+ Ok ( (
842+ PdlResult :: Dict ( dict) ,
843+ content_messages
844+ . into_iter ( )
845+ . map ( |m| ChatMessage :: new ( self . to_ollama_role ( & block. role ) , m. content ) )
846+ . collect ( ) ,
847+ PdlBlock :: Message ( MessageBlock {
848+ role : block. role . clone ( ) ,
849+ content : Box :: new ( content_trace) ,
850+ description : block. description . clone ( ) ,
851+ name : name,
852+ tool_call_id : tool_call_id,
853+ } ) ,
854+ ) )
855+ }
789856}
790857
791858pub async fn run ( program : & PdlBlock , cwd : Option < PathBuf > , debug : bool ) -> Interpretation {
0 commit comments