Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 33 additions & 32 deletions pdl-live-react/src-tauri/src/compile/beeai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ use serde_json::{Map, Value, from_reader, json, to_string};
use tempfile::Builder;

use crate::pdl::ast::{
ArrayBlockBuilder, CallBlock, EvalsTo, Expr, FunctionBlock, ListOrString, MessageBlock,
MetadataBuilder, ModelBlockBuilder, ObjectBlock, PdlBaseType, PdlBlock, PdlOptionalType,
PdlParser, PdlType, PythonCodeBlock, RepeatBlock, Role, TextBlock, TextBlockBuilder,
ArrayBlockBuilder, Block::*, CallBlock, EvalsTo, Expr, FunctionBlock, ListOrString,
MessageBlock, MetadataBuilder, ModelBlockBuilder, ObjectBlock, PdlBaseType, PdlBlock,
PdlBlock::Advanced, PdlOptionalType, PdlParser, PdlType, PythonCodeBlock, RepeatBlock, Role,
TextBlock, TextBlockBuilder,
};
use crate::pdl::pip::pip_install_if_needed;
use crate::pdl::requirements::BEEAI_FRAMEWORK;
Expand Down Expand Up @@ -190,7 +191,7 @@ fn with_tools(
}

fn call_tools(model: &String, parameters: &HashMap<String, Value>) -> PdlBlock {
let repeat = PdlBlock::Text(TextBlock {
let repeat = Advanced(Text(TextBlock {
metadata: Some(
MetadataBuilder::default()
.description("Calling tool ${ tool.function.name }".to_string())
Expand All @@ -199,19 +200,19 @@ fn call_tools(model: &String, parameters: &HashMap<String, Value>) -> PdlBlock {
),
role: None,
parser: None,
text: vec![PdlBlock::Model(
text: vec![Advanced(Model(
ModelBlockBuilder::default()
.model(model.as_str())
.parameters(strip_nulls(parameters))
.input(PdlBlock::Array(
.input(Advanced(Array(
ArrayBlockBuilder::default()
.array(vec![PdlBlock::Message(MessageBlock {
.array(vec![Advanced(Message(MessageBlock {
metadata: None,
role: Role::Tool,
defsite: None,
name: Some("${ tool.function.name }".to_string()),
tool_call_id: Some("${ tool.id }".to_string()),
content: Box::new(PdlBlock::Call(CallBlock {
content: Box::new(Advanced(Call(CallBlock {
metadata: Some(
MetadataBuilder::default()
.defs(json_loads(
Expand All @@ -226,15 +227,15 @@ fn call_tools(model: &String, parameters: &HashMap<String, Value>) -> PdlBlock {
"${ pdl__tools[tool.function.name] }".to_string(),
), // look up tool in tool_declarations def (see below)
args: Some("${ args }".into()), // invoke with arguments as specified by the model
})),
})])
}))),
}))])
.build()
.unwrap(),
))
)))
.build()
.unwrap(),
)],
});
))],
}));

let mut for_ = HashMap::new();
for_.insert(
Expand All @@ -248,11 +249,11 @@ fn call_tools(model: &String, parameters: &HashMap<String, Value>) -> PdlBlock {
);

// response.choices[0].message.tool_calls
PdlBlock::Repeat(RepeatBlock {
Advanced(Repeat(RepeatBlock {
metadata: None,
for_: for_,
repeat: Box::new(repeat),
})
}))
}

fn json_loads(
Expand All @@ -263,7 +264,7 @@ fn json_loads(
let mut m = indexmap::IndexMap::new();
m.insert(
outer_name.to_owned(),
PdlBlock::Text(
Advanced(Text(
TextBlockBuilder::default()
.text(vec![PdlBlock::String(format!(
"{{\"{}\": {}}}",
Expand All @@ -278,7 +279,7 @@ fn json_loads(
.parser(PdlParser::Json)
.build()
.unwrap(),
),
)),
);
m
}
Expand Down Expand Up @@ -427,7 +428,7 @@ pub fn compile(source_file_path: &str, debug: bool) -> Result<PdlBlock, Box<dyn
tool_name.clone(),
PdlBlock::Function(FunctionBlock {
function: schema,
return_: Box::new(PdlBlock::PythonCode(PythonCodeBlock {
return_: Box::new(Advanced(PythonCode(PythonCodeBlock {
// tool function definition
metadata: None,
lang: "python".to_string(),
Expand Down Expand Up @@ -461,7 +462,7 @@ asyncio.run(invoke())
"".to_string()
}
),
})),
}))),
}),
)
})
Expand Down Expand Up @@ -490,7 +491,7 @@ asyncio.run(invoke())
let model = format!("{}/{}", provider, model);

if let Some(instructions) = instructions {
model_call.push(PdlBlock::Text(TextBlock {
model_call.push(Advanced(Text(TextBlock {
role: Some(Role::System),
text: vec![PdlBlock::String(instructions)],
metadata: Some(
Expand All @@ -500,7 +501,7 @@ asyncio.run(invoke())
.unwrap(),
),
parser: None,
}));
})));
}

let mut model_builder = ModelBlockBuilder::default();
Expand All @@ -523,7 +524,7 @@ asyncio.run(invoke())
}
}

model_call.push(PdlBlock::Model(model_builder.build().unwrap()));
model_call.push(Advanced(Model(model_builder.build().unwrap())));

if let Some(tools) = tools {
if tools.len() > 0 {
Expand All @@ -537,7 +538,7 @@ asyncio.run(invoke())
closure_name.clone(),
PdlBlock::Function(FunctionBlock {
function: HashMap::new(),
return_: Box::new(PdlBlock::Text(TextBlock {
return_: Box::new(Advanced(Text(TextBlock {
metadata: Some(
MetadataBuilder::default()
.description(format!("Model call {}", &model))
Expand All @@ -547,10 +548,10 @@ asyncio.run(invoke())
role: None,
parser: None,
text: model_call,
})),
}))),
}),
);
PdlBlock::Text(TextBlock {
Advanced(Text(TextBlock {
metadata: Some(
MetadataBuilder::default()
.description("Model call wrapper".to_string())
Expand All @@ -560,11 +561,11 @@ asyncio.run(invoke())
),
role: None,
parser: None,
text: vec![PdlBlock::Call(CallBlock::new(format!(
text: vec![Advanced(Call(CallBlock::new(format!(
"${{ {} }}",
closure_name
)))],
})
))))],
}))
},
)
.collect::<Vec<_>>();
Expand All @@ -579,19 +580,19 @@ asyncio.run(invoke())
let mut defs = indexmap::IndexMap::new();
defs.insert(
"pdl__tools".to_string(),
PdlBlock::Object(ObjectBlock {
Advanced(Object(ObjectBlock {
object: tool_declarations,
}),
})),
);
metadata.defs(defs);
}

let pdl: PdlBlock = PdlBlock::Text(TextBlock {
let pdl: PdlBlock = Advanced(Text(TextBlock {
metadata: Some(metadata.build().unwrap()),
role: None,
parser: None,
text: body,
});
}));

Ok(pdl)
}
Expand Down
16 changes: 10 additions & 6 deletions pdl-live-react/src-tauri/src/pdl/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ impl SequencingBlock for LastOfBlock {
&self.parser
}
fn to_block(&self) -> PdlBlock {
PdlBlock::LastOf(self.clone())
PdlBlock::Advanced(Block::LastOf(self.clone()))
}
fn result_for(&self, output_results: Vec<PdlResult>) -> PdlResult {
match output_results.last() {
Expand Down Expand Up @@ -243,7 +243,7 @@ impl SequencingBlock for TextBlock {
&self.parser
}
fn to_block(&self) -> PdlBlock {
PdlBlock::Text(self.clone())
PdlBlock::Advanced(Block::Text(self.clone()))
}
fn result_for(&self, output_results: Vec<PdlResult>) -> PdlResult {
PdlResult::String(
Expand Down Expand Up @@ -596,8 +596,15 @@ pub enum PdlBlock {
Number(Number),
String(String),
Function(FunctionBlock),
Advanced(Block),
// must be last to prevent serde from aggressively matching on it, since other block types also (may) have a `defs`
Empty(EmptyBlock),
}

// the rest have Metadata; TODO refactor to make this more explicit
/// A PDL block that has structure and metadata
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum Block {
If(IfBlock),
Import(ImportBlock),
Include(IncludeBlock),
Expand All @@ -612,9 +619,6 @@ pub enum PdlBlock {
Model(ModelBlock),
LastOf(LastOfBlock),
Text(TextBlock),

// must be last to prevent serde from aggressively matching on it, since other block types also (may) have a `defs`
Empty(EmptyBlock),
}

impl From<bool> for PdlBlock {
Expand Down
35 changes: 17 additions & 18 deletions pdl-live-react/src-tauri/src/pdl/extract.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::pdl::ast::{Metadata, PdlBlock};
use crate::pdl::ast::{Block::*, Metadata, PdlBlock, PdlBlock::Advanced};

/// Extract models referenced by the programs
pub fn extract_models(program: &PdlBlock) -> Vec<String> {
Expand All @@ -20,18 +20,26 @@ pub fn extract_values(program: &PdlBlock, field: &str) -> Vec<String> {
/// Take one Yaml fragment and produce a vector of the string-valued entries of the given field
fn extract_values_iter(program: &PdlBlock, field: &str, values: &mut Vec<String>) {
match program {
PdlBlock::Model(b) => values.push(b.model.clone()),
PdlBlock::Repeat(b) => {
PdlBlock::Empty(b) => {
b.defs
.values()
.for_each(|p| extract_values_iter(p, field, values));
}
PdlBlock::Function(b) => {
extract_values_iter(&b.return_, field, values);
}
Advanced(Model(b)) => values.push(b.model.clone()),
Advanced(Repeat(b)) => {
extract_values_iter(&b.repeat, field, values);
}
PdlBlock::Message(b) => {
Advanced(Message(b)) => {
extract_values_iter(&b.content, field, values);
}
PdlBlock::Array(b) => b
Advanced(Array(b)) => b
.array
.iter()
.for_each(|p| extract_values_iter(p, field, values)),
PdlBlock::Text(b) => {
Advanced(Text(b)) => {
b.text
.iter()
.for_each(|p| extract_values_iter(p, field, values));
Expand All @@ -43,7 +51,7 @@ fn extract_values_iter(program: &PdlBlock, field: &str, values: &mut Vec<String>
.for_each(|p| extract_values_iter(p, field, values));
}
}
PdlBlock::LastOf(b) => {
Advanced(LastOf(b)) => {
b.last_of
.iter()
.for_each(|p| extract_values_iter(p, field, values));
Expand All @@ -55,7 +63,7 @@ fn extract_values_iter(program: &PdlBlock, field: &str, values: &mut Vec<String>
.for_each(|p| extract_values_iter(p, field, values));
}
}
PdlBlock::If(b) => {
Advanced(If(b)) => {
extract_values_iter(&b.then, field, values);
if let Some(else_) = &b.else_ {
extract_values_iter(else_, field, values);
Expand All @@ -68,20 +76,11 @@ fn extract_values_iter(program: &PdlBlock, field: &str, values: &mut Vec<String>
.for_each(|p| extract_values_iter(p, field, values));
}
}
PdlBlock::Empty(b) => {
b.defs
.values()
.for_each(|p| extract_values_iter(p, field, values));
}
PdlBlock::Object(b) => b
Advanced(Object(b)) => b
.object
.values()
.for_each(|p| extract_values_iter(p, field, values)),

PdlBlock::Function(b) => {
extract_values_iter(&b.return_, field, values);
}

_ => {}
}
}
Loading