diff --git a/pdl-live-react/src-tauri/src/compile/beeai.rs b/pdl-live-react/src-tauri/src/compile/beeai.rs index 9e5371afb..d46597ec4 100644 --- a/pdl-live-react/src-tauri/src/compile/beeai.rs +++ b/pdl-live-react/src-tauri/src/compile/beeai.rs @@ -12,9 +12,10 @@ use serde_json::{Map, Value, from_reader, json, to_string}; use tempfile::Builder; use crate::pdl::ast::{ - ArrayBlockBuilder, CallBlock, EvalsTo, Expr, FunctionBlock, ListOrString, MessageBlock, - MetadataBuilder, ModelBlockBuilder, ObjectBlock, PdlBaseType, PdlBlock, PdlOptionalType, - PdlParser, PdlType, PythonCodeBlock, RepeatBlock, Role, TextBlock, TextBlockBuilder, + ArrayBlockBuilder, Block::*, CallBlock, EvalsTo, Expr, FunctionBlock, ListOrString, + MessageBlock, MetadataBuilder, ModelBlockBuilder, ObjectBlock, PdlBaseType, PdlBlock, + PdlBlock::Advanced, PdlOptionalType, PdlParser, PdlType, PythonCodeBlock, RepeatBlock, Role, + TextBlock, TextBlockBuilder, }; use crate::pdl::pip::pip_install_if_needed; use crate::pdl::requirements::BEEAI_FRAMEWORK; @@ -190,7 +191,7 @@ fn with_tools( } fn call_tools(model: &String, parameters: &HashMap) -> PdlBlock { - let repeat = PdlBlock::Text(TextBlock { + let repeat = Advanced(Text(TextBlock { metadata: Some( MetadataBuilder::default() .description("Calling tool ${ tool.function.name }".to_string()) @@ -199,19 +200,19 @@ fn call_tools(model: &String, parameters: &HashMap) -> PdlBlock { ), role: None, parser: None, - text: vec![PdlBlock::Model( + text: vec![Advanced(Model( ModelBlockBuilder::default() .model(model.as_str()) .parameters(strip_nulls(parameters)) - .input(PdlBlock::Array( + .input(Advanced(Array( ArrayBlockBuilder::default() - .array(vec![PdlBlock::Message(MessageBlock { + .array(vec![Advanced(Message(MessageBlock { metadata: None, role: Role::Tool, defsite: None, name: Some("${ tool.function.name }".to_string()), tool_call_id: Some("${ tool.id }".to_string()), - content: Box::new(PdlBlock::Call(CallBlock { + content: Box::new(Advanced(Call(CallBlock { metadata: Some( MetadataBuilder::default() .defs(json_loads( @@ -226,15 +227,15 @@ fn call_tools(model: &String, parameters: &HashMap) -> PdlBlock { "${ pdl__tools[tool.function.name] }".to_string(), ), // look up tool in tool_declarations def (see below) args: Some("${ args }".into()), // invoke with arguments as specified by the model - })), - })]) + }))), + }))]) .build() .unwrap(), - )) + ))) .build() .unwrap(), - )], - }); + ))], + })); let mut for_ = HashMap::new(); for_.insert( @@ -248,11 +249,11 @@ fn call_tools(model: &String, parameters: &HashMap) -> PdlBlock { ); // response.choices[0].message.tool_calls - PdlBlock::Repeat(RepeatBlock { + Advanced(Repeat(RepeatBlock { metadata: None, for_: for_, repeat: Box::new(repeat), - }) + })) } fn json_loads( @@ -263,7 +264,7 @@ fn json_loads( let mut m = indexmap::IndexMap::new(); m.insert( outer_name.to_owned(), - PdlBlock::Text( + Advanced(Text( TextBlockBuilder::default() .text(vec![PdlBlock::String(format!( "{{\"{}\": {}}}", @@ -278,7 +279,7 @@ fn json_loads( .parser(PdlParser::Json) .build() .unwrap(), - ), + )), ); m } @@ -427,7 +428,7 @@ pub fn compile(source_file_path: &str, debug: bool) -> Result 0 { @@ -537,7 +538,7 @@ asyncio.run(invoke()) closure_name.clone(), PdlBlock::Function(FunctionBlock { function: HashMap::new(), - return_: Box::new(PdlBlock::Text(TextBlock { + return_: Box::new(Advanced(Text(TextBlock { metadata: Some( MetadataBuilder::default() .description(format!("Model call {}", &model)) @@ -547,10 +548,10 @@ asyncio.run(invoke()) role: None, parser: None, text: model_call, - })), + }))), }), ); - PdlBlock::Text(TextBlock { + Advanced(Text(TextBlock { metadata: Some( MetadataBuilder::default() .description("Model call wrapper".to_string()) @@ -560,11 +561,11 @@ asyncio.run(invoke()) ), role: None, parser: None, - text: vec![PdlBlock::Call(CallBlock::new(format!( + text: vec![Advanced(Call(CallBlock::new(format!( "${{ {} }}", closure_name - )))], - }) + ))))], + })) }, ) .collect::>(); @@ -579,19 +580,19 @@ asyncio.run(invoke()) let mut defs = indexmap::IndexMap::new(); defs.insert( "pdl__tools".to_string(), - PdlBlock::Object(ObjectBlock { + Advanced(Object(ObjectBlock { object: tool_declarations, - }), + })), ); metadata.defs(defs); } - let pdl: PdlBlock = PdlBlock::Text(TextBlock { + let pdl: PdlBlock = Advanced(Text(TextBlock { metadata: Some(metadata.build().unwrap()), role: None, parser: None, text: body, - }); + })); Ok(pdl) } diff --git a/pdl-live-react/src-tauri/src/pdl/ast.rs b/pdl-live-react/src-tauri/src/pdl/ast.rs index 1fe8d76a5..52b7eb7e9 100644 --- a/pdl-live-react/src-tauri/src/pdl/ast.rs +++ b/pdl-live-react/src-tauri/src/pdl/ast.rs @@ -183,7 +183,7 @@ impl SequencingBlock for LastOfBlock { &self.parser } fn to_block(&self) -> PdlBlock { - PdlBlock::LastOf(self.clone()) + PdlBlock::Advanced(Block::LastOf(self.clone())) } fn result_for(&self, output_results: Vec) -> PdlResult { match output_results.last() { @@ -243,7 +243,7 @@ impl SequencingBlock for TextBlock { &self.parser } fn to_block(&self) -> PdlBlock { - PdlBlock::Text(self.clone()) + PdlBlock::Advanced(Block::Text(self.clone())) } fn result_for(&self, output_results: Vec) -> PdlResult { PdlResult::String( @@ -596,8 +596,15 @@ pub enum PdlBlock { Number(Number), String(String), Function(FunctionBlock), + Advanced(Block), + // must be last to prevent serde from aggressively matching on it, since other block types also (may) have a `defs` + Empty(EmptyBlock), +} - // the rest have Metadata; TODO refactor to make this more explicit +/// A PDL block that has structure and metadata +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(untagged)] +pub enum Block { If(IfBlock), Import(ImportBlock), Include(IncludeBlock), @@ -612,9 +619,6 @@ pub enum PdlBlock { Model(ModelBlock), LastOf(LastOfBlock), Text(TextBlock), - - // must be last to prevent serde from aggressively matching on it, since other block types also (may) have a `defs` - Empty(EmptyBlock), } impl From for PdlBlock { diff --git a/pdl-live-react/src-tauri/src/pdl/extract.rs b/pdl-live-react/src-tauri/src/pdl/extract.rs index 61af006e1..62f9c4cc9 100644 --- a/pdl-live-react/src-tauri/src/pdl/extract.rs +++ b/pdl-live-react/src-tauri/src/pdl/extract.rs @@ -1,4 +1,4 @@ -use crate::pdl::ast::{Metadata, PdlBlock}; +use crate::pdl::ast::{Block::*, Metadata, PdlBlock, PdlBlock::Advanced}; /// Extract models referenced by the programs pub fn extract_models(program: &PdlBlock) -> Vec { @@ -20,18 +20,26 @@ pub fn extract_values(program: &PdlBlock, field: &str) -> Vec { /// Take one Yaml fragment and produce a vector of the string-valued entries of the given field fn extract_values_iter(program: &PdlBlock, field: &str, values: &mut Vec) { match program { - PdlBlock::Model(b) => values.push(b.model.clone()), - PdlBlock::Repeat(b) => { + PdlBlock::Empty(b) => { + b.defs + .values() + .for_each(|p| extract_values_iter(p, field, values)); + } + PdlBlock::Function(b) => { + extract_values_iter(&b.return_, field, values); + } + Advanced(Model(b)) => values.push(b.model.clone()), + Advanced(Repeat(b)) => { extract_values_iter(&b.repeat, field, values); } - PdlBlock::Message(b) => { + Advanced(Message(b)) => { extract_values_iter(&b.content, field, values); } - PdlBlock::Array(b) => b + Advanced(Array(b)) => b .array .iter() .for_each(|p| extract_values_iter(p, field, values)), - PdlBlock::Text(b) => { + Advanced(Text(b)) => { b.text .iter() .for_each(|p| extract_values_iter(p, field, values)); @@ -43,7 +51,7 @@ fn extract_values_iter(program: &PdlBlock, field: &str, values: &mut Vec .for_each(|p| extract_values_iter(p, field, values)); } } - PdlBlock::LastOf(b) => { + Advanced(LastOf(b)) => { b.last_of .iter() .for_each(|p| extract_values_iter(p, field, values)); @@ -55,7 +63,7 @@ fn extract_values_iter(program: &PdlBlock, field: &str, values: &mut Vec .for_each(|p| extract_values_iter(p, field, values)); } } - PdlBlock::If(b) => { + Advanced(If(b)) => { extract_values_iter(&b.then, field, values); if let Some(else_) = &b.else_ { extract_values_iter(else_, field, values); @@ -68,20 +76,11 @@ fn extract_values_iter(program: &PdlBlock, field: &str, values: &mut Vec .for_each(|p| extract_values_iter(p, field, values)); } } - PdlBlock::Empty(b) => { - b.defs - .values() - .for_each(|p| extract_values_iter(p, field, values)); - } - PdlBlock::Object(b) => b + Advanced(Object(b)) => b .object .values() .for_each(|p| extract_values_iter(p, field, values)), - PdlBlock::Function(b) => { - extract_values_iter(&b.return_, field, values); - } - _ => {} } } diff --git a/pdl-live-react/src-tauri/src/pdl/interpreter.rs b/pdl-live-react/src-tauri/src/pdl/interpreter.rs index 6c09697fc..43d15469b 100644 --- a/pdl-live-react/src-tauri/src/pdl/interpreter.rs +++ b/pdl-live-react/src-tauri/src/pdl/interpreter.rs @@ -21,10 +21,10 @@ use serde_json::{Value, from_str, json, to_string}; use serde_norway::{from_reader, from_str as from_yaml_str}; use crate::pdl::ast::{ - ArrayBlock, CallBlock, Closure, DataBlock, EmptyBlock, EvalsTo, Expr, FunctionBlock, IfBlock, - ImportBlock, IncludeBlock, ListOrString, MessageBlock, ModelBlock, ObjectBlock, PdlBlock, - PdlParser, PdlResult, PdlUsage, PythonCodeBlock, ReadBlock, RepeatBlock, Role, Scope, - SequencingBlock, StringOrBoolean, StringOrNull, + ArrayBlock, Block::*, CallBlock, Closure, DataBlock, EmptyBlock, EvalsTo, Expr, FunctionBlock, + IfBlock, ImportBlock, IncludeBlock, ListOrString, MessageBlock, ModelBlock, ObjectBlock, + PdlBlock, PdlBlock::Advanced, PdlParser, PdlResult, PdlUsage, PythonCodeBlock, ReadBlock, + RepeatBlock, Role, Scope, SequencingBlock, StringOrBoolean, StringOrNull, }; type Messages = Vec; @@ -137,31 +137,27 @@ impl<'a> Interpreter<'a> { PdlBlock::Function(f.clone()), )), PdlBlock::String(s) => self.run_string(s, state).await, - PdlBlock::Call(block) => self.run_call(block, state).await, PdlBlock::Empty(block) => self.run_empty(block, state).await, - PdlBlock::If(block) => self.run_if(block, state).await, - PdlBlock::Import(block) => self.run_import(block, state).await, - PdlBlock::Include(block) => self.run_include(block, state).await, - PdlBlock::Model(block) => self.run_model(block, state).await, - PdlBlock::Data(block) => self.run_data(block, state).await, - PdlBlock::Object(block) => self.run_object(block, state).await, - PdlBlock::PythonCode(block) => self.run_python_code(block, state).await, - PdlBlock::Read(block) => self.run_read(block, state).await, - PdlBlock::Repeat(block) => self.run_repeat(block, state).await, - PdlBlock::LastOf(block) => self.run_sequence(block, state).await, - PdlBlock::Text(block) => self.run_sequence(block, state).await, - PdlBlock::Array(block) => self.run_array(block, state).await, - PdlBlock::Message(block) => self.run_message(block, state).await, + Advanced(Call(block)) => self.run_call(block, state).await, + Advanced(If(block)) => self.run_if(block, state).await, + Advanced(Import(block)) => self.run_import(block, state).await, + Advanced(Include(block)) => self.run_include(block, state).await, + Advanced(Model(block)) => self.run_model(block, state).await, + Advanced(Data(block)) => self.run_data(block, state).await, + Advanced(Object(block)) => self.run_object(block, state).await, + Advanced(PythonCode(block)) => self.run_python_code(block, state).await, + Advanced(Read(block)) => self.run_read(block, state).await, + Advanced(Repeat(block)) => self.run_repeat(block, state).await, + Advanced(LastOf(block)) => self.run_sequence(block, state).await, + Advanced(Text(block)) => self.run_sequence(block, state).await, + Advanced(Array(block)) => self.run_array(block, state).await, + Advanced(Message(block)) => self.run_message(block, state).await, }?; if match program { - PdlBlock::Message(_) - | PdlBlock::Text(_) - | PdlBlock::Import(_) - | PdlBlock::Include(_) - | PdlBlock::LastOf(_) - | PdlBlock::Call(_) - | PdlBlock::Model(_) => false, + Advanced(Message(_)) | Advanced(Text(_)) | Advanced(Import(_)) + | Advanced(Include(_)) | Advanced(LastOf(_)) | Advanced(Call(_)) + | Advanced(Model(_)) => false, _ => state.emit, } { println!("{}", pretty_print(&messages)); @@ -424,7 +420,7 @@ impl<'a> Interpreter<'a> { Ok(( result, vec![ChatMessage::user(buffer)], - PdlBlock::Read(trace), + Advanced(Read(trace)), )) } @@ -485,7 +481,7 @@ impl<'a> Interpreter<'a> { } else { match &block.else_ { Some(else_block) => self.run_quiet(&else_block, state).await, - None => Ok(("".into(), vec![], PdlBlock::If(block.clone()))), + None => Ok(("".into(), vec![], Advanced(If(block.clone())))), } } } @@ -628,7 +624,7 @@ impl<'a> Interpreter<'a> { } }?; let messages = vec![ChatMessage::user(result_string.as_str().to_string())]; - let trace = PdlBlock::PythonCode(block.clone()); + let trace = Advanced(PythonCode(block.clone())); Ok((messages[0].content.clone().into(), messages, trace)) } Err(_) => Err(Box::from( @@ -764,11 +760,11 @@ impl<'a> Interpreter<'a> { Ok(( res.message.content.into(), output_messages, - PdlBlock::Model(trace), + Advanced(Model(trace)), )) } else { // nothing came out of the model - Ok(("".into(), vec![], PdlBlock::Model(trace))) + Ok(("".into(), vec![], Advanced(Model(trace)))) } // dbg!(history); } @@ -791,7 +787,7 @@ impl<'a> Interpreter<'a> { state, true, )?; - Ok((result, vec![], PdlBlock::Data(trace))) + Ok((result, vec![], Advanced(Data(trace)))) } else { let result = self.def( &block.metadata.as_ref().and_then(|m| m.def.clone()), @@ -801,7 +797,7 @@ impl<'a> Interpreter<'a> { true, )?; trace.data = from_str(to_string(&result)?.as_str())?; - Ok((result, vec![], PdlBlock::Data(trace))) + Ok((result, vec![], Advanced(Data(trace)))) } } @@ -825,7 +821,7 @@ impl<'a> Interpreter<'a> { Ok(( PdlResult::Dict(result_map), messages, - PdlBlock::Object(ObjectBlock { object: trace_map }), + Advanced(Object(ObjectBlock { object: trace_map })), )) } @@ -872,7 +868,7 @@ impl<'a> Interpreter<'a> { Ok(( PdlResult::List(results), messages, - PdlBlock::Repeat(block.clone()), + Advanced(Repeat(block.clone())), )) } @@ -993,7 +989,7 @@ impl<'a> Interpreter<'a> { Ok(( PdlResult::List(result_items), all_messages, - PdlBlock::Array(trace), + Advanced(Array(trace)), )) } @@ -1031,14 +1027,14 @@ impl<'a> Interpreter<'a> { .into_iter() .map(|m| ChatMessage::new(self.to_ollama_role(&block.role), m.content)) .collect(), - PdlBlock::Message(MessageBlock { + Advanced(Message(MessageBlock { metadata: block.metadata.clone(), role: block.role.clone(), content: Box::new(content_trace), name: name, defsite: None, tool_call_id: tool_call_id, - }), + })), )) } } diff --git a/pdl-live-react/src-tauri/src/pdl/interpreter_tests.rs b/pdl-live-react/src-tauri/src/pdl/interpreter_tests.rs index 49c495bb5..fcb10546d 100644 --- a/pdl-live-react/src-tauri/src/pdl/interpreter_tests.rs +++ b/pdl-live-react/src-tauri/src/pdl/interpreter_tests.rs @@ -5,7 +5,7 @@ mod tests { use serde_json::json; use crate::pdl::{ - ast::{ModelBlockBuilder, PdlBlock, Scope}, + ast::{Block::*, ModelBlockBuilder, PdlBlock, PdlBlock::Advanced, Scope}, interpreter::{RunOptions, load_scope, run_json_sync as run_json, run_sync as run}, }; @@ -61,12 +61,12 @@ mod tests { #[test] fn single_model_via_input_string() -> Result<(), Box> { let (_, messages, _) = run( - &PdlBlock::Model( + &Advanced(Model( ModelBlockBuilder::default() .model(DEFAULT_MODEL) .input(Box::from(PdlBlock::String("hello".to_string()))) .build()?, - ), + )), None, streaming(), initial_scope(),