@@ -5,6 +5,7 @@ use ::std::error::Error;
55use :: std:: fs:: File ;
66// use ::std::path::PathBuf;
77
8+ use async_recursion:: async_recursion;
89use minijinja:: { syntax:: SyntaxConfig , Environment } ;
910use owo_colors:: OwoColorize ;
1011
@@ -27,7 +28,8 @@ use crate::pdl::ast::{
2728
2829type Context = Vec < ChatMessage > ;
2930type Scope = HashMap < String , Value > ;
30- type Interpretation = Result < ( Context , PdlBlock ) , Box < dyn Error > > ;
31+ type Interpretation = Result < ( Context , PdlBlock ) , Box < dyn Error + Send + Sync > > ;
32+ type InterpretationSync = Result < ( Context , PdlBlock ) , Box < dyn Error > > ;
3133
3234struct Interpreter < ' a > {
3335 // batch: u32,
@@ -65,7 +67,7 @@ impl<'a> Interpreter<'a> {
6567 }
6668 }
6769
68- fn run_with_emit (
70+ async fn run_with_emit (
6971 & mut self ,
7072 program : & PdlBlock ,
7173 context : Context ,
@@ -75,12 +77,12 @@ impl<'a> Interpreter<'a> {
7577 self . emit = emit;
7678
7779 let ( messages, trace) = match program {
78- PdlBlock :: String ( s) => self . run_string ( s, context) ,
79- PdlBlock :: Call ( block) => self . run_call ( block, context) ,
80- PdlBlock :: PythonCode ( block) => self . run_python_code ( block, context) ,
81- PdlBlock :: Model ( block) => self . run_model ( block, context) ,
82- PdlBlock :: Repeat ( block) => self . run_repeat ( block, context) ,
83- PdlBlock :: Text ( block) => self . run_text ( block, context) ,
80+ PdlBlock :: String ( s) => self . run_string ( s, context) . await ,
81+ PdlBlock :: Call ( block) => self . run_call ( block, context) . await ,
82+ PdlBlock :: PythonCode ( block) => self . run_python_code ( block, context) . await ,
83+ PdlBlock :: Model ( block) => self . run_model ( block, context) . await ,
84+ PdlBlock :: Repeat ( block) => self . run_repeat ( block, context) . await ,
85+ PdlBlock :: Text ( block) => self . run_text ( block, context) . await ,
8486 _ => Err ( Box :: from ( format ! ( "Unsupported block {:?}" , program) ) ) ,
8587 } ?;
8688
@@ -96,19 +98,21 @@ impl<'a> Interpreter<'a> {
9698 Ok ( ( messages, trace) )
9799 }
98100
99- fn run_quiet ( & mut self , program : & PdlBlock , context : Context ) -> Interpretation {
100- self . run_with_emit ( program, context, false )
101+ #[ async_recursion]
102+ async fn run_quiet ( & mut self , program : & PdlBlock , context : Context ) -> Interpretation {
103+ self . run_with_emit ( program, context, false ) . await
101104 }
102105
103- fn run ( & mut self , program : & PdlBlock , context : Context ) -> Interpretation {
104- self . run_with_emit ( program, context, self . emit )
106+ #[ async_recursion]
107+ async fn run ( & mut self , program : & PdlBlock , context : Context ) -> Interpretation {
108+ self . run_with_emit ( program, context, self . emit ) . await
105109 }
106110
107111 // Evaluate as a Jinja2 expression
108112 fn eval < T : serde:: de:: DeserializeOwned + :: std:: convert:: From < String > > (
109113 & self ,
110114 expr : & String ,
111- ) -> Result < T , Box < dyn Error > > {
115+ ) -> Result < T , Box < dyn Error + Send + Sync > > {
112116 let result = self
113117 . jinja_env
114118 . render_str ( expr. as_str ( ) , self . scope . last ( ) . unwrap_or ( & HashMap :: new ( ) ) ) ?;
@@ -127,7 +131,7 @@ impl<'a> Interpreter<'a> {
127131 }
128132
129133 // Run a PdlBlock::String
130- fn run_string ( & self , msg : & String , _context : Context ) -> Interpretation {
134+ async fn run_string ( & self , msg : & String , _context : Context ) -> Interpretation {
131135 let trace = self . eval :: < PdlBlock > ( msg) ?;
132136 if self . debug {
133137 eprintln ! ( "String {} -> {:?}" , msg, trace) ;
@@ -142,7 +146,7 @@ impl<'a> Interpreter<'a> {
142146 }
143147
144148 // Run a PdlBlock::Call
145- fn run_call ( & mut self , block : & PdlCallBlock , context : Context ) -> Interpretation {
149+ async fn run_call ( & mut self , block : & PdlCallBlock , context : Context ) -> Interpretation {
146150 if self . debug {
147151 eprintln ! ( "Call {:?}({:?})" , block. call, block. args) ;
148152 }
@@ -154,7 +158,7 @@ impl<'a> Interpreter<'a> {
154158 // args was a string that eval'd to an Object
155159 Value :: Object ( m) => Ok ( Some ( self . to_pdl ( & m) ) ) ,
156160 // args was a string that eval'd to something we don't understand
157- y => Err ( Box :: < dyn Error > :: from ( format ! (
161+ y => Err ( Box :: < dyn Error + Send + Sync > :: from ( format ! (
158162 "Invalid arguments to call {:?}" ,
159163 y
160164 ) ) ) ,
@@ -170,7 +174,7 @@ impl<'a> Interpreter<'a> {
170174 self . extend_scope_with_map ( & args) ;
171175
172176 let res = match self . eval :: < PdlBlock > ( & block. call ) ? {
173- PdlBlock :: Function ( f) => self . run ( & f. return_ , context. clone ( ) ) ,
177+ PdlBlock :: Function ( f) => self . run ( & f. return_ , context. clone ( ) ) . await ,
174178 _ => Err ( Box :: from ( format ! ( "call of non-function {:?}" , & block. call) ) ) ,
175179 } ;
176180 self . scope . pop ( ) ;
@@ -230,7 +234,11 @@ impl<'a> Interpreter<'a> {
230234 }
231235
232236 // Run a PdlBlock::PythonCode
233- fn run_python_code ( & mut self , block : & PdlPythonCodeBlock , context : Context ) -> Interpretation {
237+ async fn run_python_code (
238+ & mut self ,
239+ block : & PdlPythonCodeBlock ,
240+ context : Context ,
241+ ) -> Interpretation {
234242 use rustpython_vm as vm;
235243 vm:: Interpreter :: without_stdlib ( Default :: default ( ) ) . enter ( |vm| -> Interpretation {
236244 let scope = vm. new_scope_with_builtins ( ) ;
@@ -275,7 +283,7 @@ impl<'a> Interpreter<'a> {
275283 }
276284
277285 // Run a PdlBlock::Model
278- fn run_model ( & mut self , block : & PdlModelBlock , context : Context ) -> Interpretation {
286+ async fn run_model ( & mut self , block : & PdlModelBlock , context : Context ) -> Interpretation {
279287 match & block. model {
280288 pdl_model
281289 if pdl_model. starts_with ( "ollama/" ) || pdl_model. starts_with ( "ollama_chat/" ) =>
@@ -295,7 +303,7 @@ impl<'a> Interpreter<'a> {
295303 let messages = match & block. input {
296304 Some ( input) => {
297305 // TODO ignoring trace
298- let ( messages, _trace) = self . run_quiet ( & * input, context) ?;
306+ let ( messages, _trace) = self . run_quiet ( & * input, context) . await ?;
299307 messages
300308 }
301309 None => context,
@@ -310,24 +318,27 @@ impl<'a> Interpreter<'a> {
310318 eprintln ! (
311319 "Ollama {:?} model={:?} prompt={:?} history={:?}" ,
312320 block. description. clone( ) . unwrap_or( "" . into( ) ) ,
313- & block. model,
314- & prompt,
315- & history
321+ block. model,
322+ prompt,
323+ history
316324 ) ;
317325 }
318326
319327 let req = ChatMessageRequest :: new ( model. into ( ) , vec ! [ prompt. clone( ) ] )
320328 . options ( options)
321329 . tools ( tools) ;
322- let res = /*self.rt.*/ tauri:: async_runtime:: block_on ( ollama. send_chat_messages_with_history (
323- & mut history,
324- req,
325- //ollama.generate(GenerationRequest::new(model.into(), prompt),
326- ) ) ?;
330+ let res = ollama
331+ . send_chat_messages_with_history (
332+ & mut history,
333+ req,
334+ //ollama.generate(GenerationRequest::new(model.into(), prompt),
335+ )
336+ . await ?;
327337 // dbg!("Model result {:?}", &res);
328338
329339 let mut trace = block. clone ( ) ;
330340 trace. pdl_result = Some ( res. message . content . clone ( ) ) ;
341+
331342 if let Some ( usage) = res. final_data {
332343 trace. pdl_usage = Some ( PdlUsage {
333344 prompt_tokens : usage. prompt_eval_count ,
@@ -344,7 +355,7 @@ impl<'a> Interpreter<'a> {
344355 }
345356
346357 // Run a PdlBlock::Repeat
347- fn run_repeat ( & mut self , block : & PdlRepeatBlock , _context : Context ) -> Interpretation {
358+ async fn run_repeat ( & mut self , block : & PdlRepeatBlock , _context : Context ) -> Interpretation {
348359 let for_ = block
349360 . for_
350361 . iter ( )
@@ -368,7 +379,11 @@ impl<'a> Interpreter<'a> {
368379 }
369380 }
370381
371- fn parse_result ( & self , parser : & PdlParser , result : & String ) -> Result < Value , Box < dyn Error > > {
382+ fn parse_result (
383+ & self ,
384+ parser : & PdlParser ,
385+ result : & String ,
386+ ) -> Result < Value , Box < dyn Error + Send + Sync > > {
372387 match parser {
373388 PdlParser :: Json => Ok ( from_str ( result) ?) ,
374389 }
@@ -395,7 +410,7 @@ impl<'a> Interpreter<'a> {
395410 }
396411
397412 // Run a PdlBlock::Text
398- fn run_text ( & mut self , block : & PdlTextBlock , context : Context ) -> Interpretation {
413+ async fn run_text ( & mut self , block : & PdlTextBlock , context : Context ) -> Interpretation {
399414 if self . debug {
400415 eprintln ! (
401416 "Text {:?}" ,
@@ -411,14 +426,14 @@ impl<'a> Interpreter<'a> {
411426 let mut output_blocks = vec ! [ ] ;
412427
413428 self . extend_scope_with_map ( & block. defs ) ;
414- block. text . iter ( ) . try_for_each ( |block| {
429+ let mut iter = block. text . iter ( ) ;
430+ while let Some ( block) = iter. next ( ) {
415431 // run each element of the Text block
416- let ( this_messages, trace) = self . run ( & block, input_messages. clone ( ) ) ?;
432+ let ( this_messages, trace) = self . run ( & block, input_messages. clone ( ) ) . await ?;
417433 input_messages. extend ( this_messages. clone ( ) ) ;
418434 output_messages. extend ( this_messages) ;
419435 output_blocks. push ( trace) ;
420- Ok :: < ( ) , Box < dyn Error > > ( ( ) )
421- } ) ?;
436+ }
422437 self . scope . pop ( ) ;
423438
424439 let mut trace = block. clone ( ) ;
@@ -455,22 +470,37 @@ impl<'a> Interpreter<'a> {
455470 }
456471}
457472
458- pub fn run ( program : & PdlBlock , debug : bool ) -> Interpretation {
473+ pub async fn run ( program : & PdlBlock , debug : bool ) -> Interpretation {
459474 let mut interpreter = Interpreter :: new ( ) ;
460475 interpreter. debug = debug;
461- interpreter. run ( & program, vec ! [ ] )
476+ interpreter. run ( & program, vec ! [ ] ) . await
477+ }
478+
479+ pub fn run_sync ( program : & PdlBlock , debug : bool ) -> InterpretationSync {
480+ tauri:: async_runtime:: block_on ( run ( program, debug) )
481+ . map_err ( |err| Box :: < dyn :: std:: error:: Error > :: from ( err. to_string ( ) ) )
482+ }
483+
484+ pub async fn run_file ( source_file_path : & str , debug : bool ) -> Interpretation {
485+ run ( & from_reader ( File :: open ( source_file_path) ?) ?, debug) . await
486+ }
487+
488+ pub fn run_file_sync ( source_file_path : & str , debug : bool ) -> InterpretationSync {
489+ tauri:: async_runtime:: block_on ( run_file ( source_file_path, debug) )
490+ . map_err ( |err| Box :: < dyn :: std:: error:: Error > :: from ( err. to_string ( ) ) )
462491}
463492
464- pub fn run_file ( source_file_path : & str , debug : bool ) -> Interpretation {
465- run ( & from_reader ( File :: open ( source_file_path ) ? ) ? , debug)
493+ pub async fn run_string ( source : & str , debug : bool ) -> Interpretation {
494+ run ( & from_yaml_str ( source ) ? , debug) . await
466495}
467496
468- pub fn run_string ( source : & str , debug : bool ) -> Interpretation {
469- run ( & from_yaml_str ( source) ?, debug)
497+ pub async fn run_json ( source : Value , debug : bool ) -> Interpretation {
498+ run_string ( & to_string ( & source) ?, debug) . await
470499}
471500
472- pub fn run_json ( source : Value , debug : bool ) -> Interpretation {
473- run_string ( & to_string ( & source) ?, debug)
501+ pub fn run_json_sync ( source : Value , debug : bool ) -> InterpretationSync {
502+ tauri:: async_runtime:: block_on ( run_json ( source, debug) )
503+ . map_err ( |err| Box :: < dyn :: std:: error:: Error > :: from ( err. to_string ( ) ) )
474504}
475505
476506pub fn pretty_print ( messages : & Vec < ChatMessage > ) -> String {
0 commit comments