Skip to content

Commit d70c52d

Browse files
committed
call block with args
Signed-off-by: Nick Mitchell <[email protected]>
1 parent 26f2751 commit d70c52d

File tree

4 files changed

+93
-34
lines changed

4 files changed

+93
-34
lines changed

pdl-live-react/src-tauri/src/compile/beeai.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use tempfile::Builder;
1313

1414
use crate::pdl::ast::{
1515
PdlBaseType, PdlBlock, PdlCallBlock, PdlFunctionBlock, PdlMessageBlock, PdlModelBlock,
16-
PdlOptionalType, PdlParser, PdlRepeatBlock, PdlTextBlock, PdlType, Role,
16+
PdlObjectBlock, PdlOptionalType, PdlParser, PdlRepeatBlock, PdlTextBlock, PdlType, Role,
1717
};
1818
use crate::pdl::pip::pip_install_if_needed;
1919
use crate::pdl::requirements::BEEAI_FRAMEWORK;
@@ -514,9 +514,9 @@ asyncio.run(invoke())
514514
let mut m = HashMap::new();
515515
m.insert(
516516
"pdl__tools".to_string(),
517-
PdlBlock::Object {
517+
PdlBlock::Object(PdlObjectBlock {
518518
object: tool_declarations,
519-
},
519+
}),
520520
);
521521
Some(m)
522522
},

pdl-live-react/src-tauri/src/pdl/ast.rs

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use ::std::collections::HashMap;
22
use serde::{Deserialize, Serialize};
3-
use serde_json::{Map, Value};
3+
use serde_json::{Number, Value};
44

55
#[derive(Serialize, Deserialize, Debug, Clone)]
66
#[serde(rename_all_fields = "lowercase")]
@@ -50,7 +50,7 @@ pub enum PdlType {
5050
pub struct PdlCallBlock {
5151
pub call: String,
5252
#[serde(skip_serializing_if = "Option::is_none")]
53-
pub args: Option<Box<PdlBlock>>,
53+
pub args: Option<Value>,
5454
#[serde(skip_serializing_if = "Option::is_none")]
5555
pub defs: Option<HashMap<String, PdlBlock>>,
5656
}
@@ -210,16 +210,23 @@ pub struct PdlMessageBlock {
210210
pub tool_call_id: Option<String>,
211211
}
212212

213+
#[derive(Serialize, Deserialize, Debug, Clone)]
214+
pub struct PdlObjectBlock {
215+
pub object: HashMap<String, PdlBlock>,
216+
}
217+
213218
#[derive(Serialize, Deserialize, Debug, Clone)]
214219
#[serde(untagged)]
215220
pub enum PdlBlock {
221+
Bool(bool),
222+
Number(Number),
216223
String(String),
217224
/*If {
218225
#[serde(rename = "if")]
219226
condition: String,
220227
then: Box<PdlBlock>,
221228
},*/
222-
Object { object: HashMap<String, PdlBlock> },
229+
Object(PdlObjectBlock),
223230
Call(PdlCallBlock),
224231
Array { array: Vec<PdlBlock> },
225232
Message(PdlMessageBlock),
@@ -236,6 +243,12 @@ impl From<&str> for PdlBlock {
236243
}
237244
}
238245

246+
impl From<String> for PdlBlock {
247+
fn from(s: String) -> Self {
248+
PdlBlock::String(s.clone())
249+
}
250+
}
251+
239252
impl From<&str> for Box<PdlBlock> {
240253
fn from(s: &str) -> Self {
241254
Box::new(PdlBlock::String(s.into()))

pdl-live-react/src-tauri/src/pdl/interpreter.rs

Lines changed: 60 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use ollama_rs::{
1717
Ollama,
1818
};
1919

20-
use serde_json::{from_str, to_string, Value};
20+
use serde_json::{from_str, to_string, Map, Value};
2121
use serde_norway::from_reader;
2222
use tokio::runtime::Runtime;
2323

@@ -104,28 +104,29 @@ impl<'a> Interpreter<'a> {
104104
}
105105

106106
// Evaluate as a Jinja2 expression
107-
fn eval(&self, expr: &String) -> Result<PdlBlock, Box<dyn Error>> {
107+
fn eval<T: serde::de::DeserializeOwned + ::std::convert::From<String>>(
108+
&self,
109+
expr: &String,
110+
) -> Result<T, Box<dyn Error>> {
108111
let result = self
109112
.jinja_env
110113
.render_str(expr.as_str(), self.scope.last().unwrap_or(&HashMap::new()))?;
111114
if self.debug {
112115
eprintln!("Eval '{}' -> {}", &expr, &result);
113116
}
114117

115-
match from_str(&result) {
116-
Err(_) => {
117-
if self.debug {
118-
eprintln!("Plain string {}", &result);
119-
}
120-
Ok(PdlBlock::String(result))
118+
let backup = result.clone();
119+
Ok(from_str(&result).unwrap_or_else(|_err| {
120+
if self.debug {
121+
eprintln!("Plain string {}", &result);
121122
}
122-
Ok(x) => Ok(x),
123-
}
123+
backup.into()
124+
}))
124125
}
125126

126127
// Run a PdlBlock::String
127128
fn run_string(&self, msg: &String, _context: Context) -> Interpretation {
128-
let trace = self.eval(msg)?;
129+
let trace = self.eval::<PdlBlock>(msg)?;
129130
if self.debug {
130131
eprintln!("String {} -> {:?}", msg, trace);
131132
}
@@ -141,12 +142,56 @@ impl<'a> Interpreter<'a> {
141142
// Run a PdlBlock::Call
142143
fn run_call(&mut self, block: &PdlCallBlock, context: Context) -> Interpretation {
143144
if self.debug {
144-
eprintln!("Call {:?} {:?}", block.call, block.args);
145+
eprintln!("Call {:?}({:?})", block.call, block.args);
145146
}
146-
match self.eval(&block.call)? {
147+
148+
let args = match &block.args {
149+
Some(x) => match x {
150+
// args is a string; eval it and see if we get an Object out the other side
151+
Value::String(s) => match self.eval::<Value>(&s)? {
152+
// args was a string that eval'd to an Object
153+
Value::Object(m) => Ok(Some(self.to_pdl(&m))),
154+
// args was a string that eval'd to something we don't understand
155+
y => Err(Box::<dyn Error>::from(format!(
156+
"Invalid arguments to call {:?}",
157+
y
158+
))),
159+
},
160+
// args is already an Object
161+
Value::Object(m) => Ok(Some(self.to_pdl(&m))),
162+
// args is something we don't understand
163+
y => Err(Box::from(format!("Invalid arguments to call {:?}", y))),
164+
},
165+
// no args... that's ok (TODO: check against function schema)
166+
None => Ok(None),
167+
}?;
168+
self.extend_scope_with_map(&args);
169+
170+
let res = match self.eval::<PdlBlock>(&block.call)? {
147171
PdlBlock::Function(f) => self.run(&f.return_, context.clone()),
148172
_ => Err(Box::from("call of non-function")),
149-
}
173+
};
174+
self.scope.pop();
175+
176+
res
177+
}
178+
179+
fn to_pdl(&self, m: &Map<String, Value>) -> HashMap<String, PdlBlock> {
180+
m.into_iter()
181+
.map(|(k, v)| {
182+
(
183+
k.clone(),
184+
match v {
185+
Value::String(s) => PdlBlock::String(s.clone()),
186+
Value::Number(n) => PdlBlock::Number(n.clone()),
187+
x => {
188+
eprintln!("Unhandled arg value {:?}", x);
189+
"error".into()
190+
}
191+
},
192+
)
193+
})
194+
.collect()
150195
}
151196

152197
fn to_ollama_model_options(
@@ -256,7 +301,7 @@ impl<'a> Interpreter<'a> {
256301
let for_ = block
257302
.for_
258303
.iter()
259-
.map(|(var, values)| (var, self.eval(&values)));
304+
.map(|(var, values)| (var, self.eval::<PdlBlock>(&values)));
260305

261306
if self.debug {
262307
eprintln!("Repeat {:?}", &for_);

pdl-live-react/src-tauri/src/pdl/interpreter_tests.rs

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -70,23 +70,24 @@ mod tests {
7070
assert_eq!(messages[0].role, MessageRole::User);
7171
assert_eq!(messages[0].content, "what is the fastest animal?");
7272
assert_eq!(messages[1].role, MessageRole::Assistant);
73-
assert!(messages[1].content.contains("cheetah"));
73+
let m1 = messages[1].content.to_lowercase();
74+
assert!(m1.contains("cheetah") || m1.contains("springbok"));
7475
assert_eq!(messages[2].role, MessageRole::User);
7576
assert_eq!(messages[2].content, "in europe?");
7677
assert_eq!(messages[3].role, MessageRole::Assistant);
7778

78-
let m = messages[3].content.to_lowercase();
79+
let m3 = messages[3].content.to_lowercase();
7980
assert!(
80-
m.contains("peregrine")
81-
|| m.contains("bison")
82-
|| m.contains("hare")
83-
|| m.contains("golden eagle")
84-
|| m.contains("greyhound")
85-
|| m.contains("gazelle")
86-
|| m.contains("lynx")
87-
|| m.contains("boar")
88-
|| m.contains("sailfish")
89-
|| m.contains("pronghorn")
81+
m3.contains("peregrine")
82+
|| m3.contains("bison")
83+
|| m3.contains("hare")
84+
|| m3.contains("golden eagle")
85+
|| m3.contains("greyhound")
86+
|| m3.contains("gazelle")
87+
|| m3.contains("lynx")
88+
|| m3.contains("boar")
89+
|| m3.contains("sailfish")
90+
|| m3.contains("pronghorn")
9091
);
9192
Ok(())
9293
}
@@ -145,7 +146,7 @@ mod tests {
145146
"x": "int"
146147
},
147148
"return": {
148-
"description": "nullary function",
149+
"description": "unary function",
149150
"text": [
150151
"hello world ${x+1}"
151152
]

0 commit comments

Comments
 (0)