Skip to content

Commit 485089a

Browse files
committed
some call args work, plus proper minijinja variable delimieters
Signed-off-by: Nick Mitchell <[email protected]>
1 parent 09cc3d9 commit 485089a

File tree

6 files changed

+89
-32
lines changed

6 files changed

+89
-32
lines changed

pdl-live-react/src-tauri/Cargo.lock

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pdl-live-react/src-tauri/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ sha2 = "0.10.8"
3434
base64ct = { version = "1.7.1", features = ["alloc"] }
3535
dirs = "6.0.0"
3636
serde_norway = "0.9.42"
37-
minijinja = "2.8.0"
37+
minijinja = { version = "2.9.0", features = ["custom_syntax"] }
3838
ollama-rs = { version = "0.3.0", features = ["tokio"] }
3939
tokio = "1.44.1"
4040
owo-colors = "4.2.0"

pdl-live-react/src-tauri/src/compile/beeai.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ fn call_tools(model: &String, parameters: &HashMap<String, Value>) -> PdlBlock {
188188
&"${ tool.function.arguments }",
189189
),
190190
call: "${ pdl__tools[tool.function.name] }".to_string(), // look up tool in tool_declarations def (see below)
191-
args: Some("${ args }".to_string()), // invoke with arguments as specified by the model
191+
args: Some("${ args }".into()), // invoke with arguments as specified by the model
192192
})),
193193
})],
194194
})

pdl-live-react/src-tauri/src/pdl/ast.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use ::std::collections::HashMap;
22
use serde::{Deserialize, Serialize};
3-
use serde_json::Value;
3+
use serde_json::{Map, Value};
44

55
#[derive(Serialize, Deserialize, Debug, Clone)]
66
#[serde(rename_all_fields = "lowercase")]
@@ -50,7 +50,7 @@ pub enum PdlType {
5050
pub struct PdlCallBlock {
5151
pub call: String,
5252
#[serde(skip_serializing_if = "Option::is_none")]
53-
pub args: Option<String>,
53+
pub args: Option<Box<PdlBlock>>,
5454
#[serde(skip_serializing_if = "Option::is_none")]
5555
pub defs: Option<HashMap<String, PdlBlock>>,
5656
}
@@ -235,3 +235,9 @@ impl From<&str> for PdlBlock {
235235
PdlBlock::String(s.into())
236236
}
237237
}
238+
239+
impl From<&str> for Box<PdlBlock> {
240+
fn from(s: &str) -> Self {
241+
Box::new(PdlBlock::String(s.into()))
242+
}
243+
}

pdl-live-react/src-tauri/src/pdl/interpreter.rs

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ use ::std::error::Error;
55
use ::std::fs::File;
66
// use ::std::path::PathBuf;
77

8-
use minijinja::Environment;
9-
use owo_colors::{FgColorDisplay, OwoColorize};
8+
use minijinja::{syntax::SyntaxConfig, Environment};
9+
use owo_colors::OwoColorize;
1010

1111
use ollama_rs::{
1212
generation::{
@@ -68,11 +68,17 @@ impl Interpreter {
6868
Ok((messages, trace))
6969
}
7070

71-
fn eval(&self, pdl_expr: &String) -> Result<PdlBlock, Box<dyn Error>> {
72-
let expr = pdl_expr.replace("${ ", "{{").replace(" }", "}}"); // FIXME regexp
71+
// Evaluate as a Jinja2 expression
72+
fn eval(&self, expr: &String) -> Result<PdlBlock, Box<dyn Error>> {
7373
let mut env = Environment::new();
74-
env.add_template(pdl_expr, expr.as_str())?;
75-
let tmpl = env.get_template(pdl_expr)?;
74+
// PDL uses custom variable delimeters, because {{ }} have pre-defined meaning in yaml
75+
env.set_syntax(
76+
SyntaxConfig::builder()
77+
.variable_delimiters("${", "}")
78+
.build()?,
79+
);
80+
env.add_template(expr, expr.as_str())?;
81+
let tmpl = env.get_template(expr)?;
7682
let result = tmpl.render(self.scope.last().unwrap_or(&HashMap::new()))?;
7783
if self.debug {
7884
eprintln!("Eval '{}' -> {}", &expr, &result);
@@ -89,6 +95,7 @@ impl Interpreter {
8995
}
9096
}
9197

98+
// Run a PdlBlock::String
9299
fn run_string(&self, msg: &String, _context: Context) -> Interpretation {
93100
let trace = self.eval(msg)?;
94101
if self.debug {
@@ -103,9 +110,10 @@ impl Interpreter {
103110
Ok((messages, trace))
104111
}
105112

113+
// Run a PdlBlock::Call
106114
fn run_call(&mut self, block: &PdlCallBlock, context: Context) -> Interpretation {
107115
if self.debug {
108-
eprintln!("Call {:?}", &block.call);
116+
eprintln!("Call {:?} {:?}", block.call, block.args);
109117
}
110118
match self.eval(&block.call)? {
111119
PdlBlock::Function(f) => self.run(&f.return_, context.clone()),
@@ -146,6 +154,7 @@ impl Interpreter {
146154
}
147155
}
148156

157+
// Run a PdlBlock::Model
149158
fn run_model(&mut self, block: &PdlModelBlock, context: Context) -> Interpretation {
150159
match &block.model {
151160
pdl_model
@@ -159,6 +168,7 @@ impl Interpreter {
159168
};
160169

161170
let (options, tools) = self.to_ollama_model_options(&block.parameters);
171+
println!("MODEL OPTIONS {:?}", options);
162172

163173
let messages = match &block.input {
164174
Some(input) => {
@@ -211,6 +221,7 @@ impl Interpreter {
211221
}
212222
}
213223

224+
// Run a PdlBlock::Repeat
214225
fn run_repeat(&mut self, block: &PdlRepeatBlock, _context: Context) -> Interpretation {
215226
let for_ = block
216227
.for_
@@ -241,18 +252,9 @@ impl Interpreter {
241252
}
242253
}
243254

244-
fn run_text(&mut self, block: &PdlTextBlock, context: Context) -> Interpretation {
245-
if self.debug {
246-
eprintln!(
247-
"Text {:?}",
248-
block
249-
.description
250-
.clone()
251-
.unwrap_or("<no description>".to_string())
252-
);
253-
}
255+
fn extend_scope_with_map(&mut self, map: &Option<HashMap<String, PdlBlock>>) {
254256
let cur_scope = self.scope.last().unwrap_or(&HashMap::new()).clone();
255-
let new_scope: Scope = match &block.defs {
257+
let new_scope = match map {
256258
Some(defs) => {
257259
// this is all non-optimal
258260
let mut scope: Scope = HashMap::from(cur_scope);
@@ -267,11 +269,26 @@ impl Interpreter {
267269
None => cur_scope,
268270
};
269271

272+
self.scope.push(new_scope);
273+
}
274+
275+
// Run a PdlBlock::Text
276+
fn run_text(&mut self, block: &PdlTextBlock, context: Context) -> Interpretation {
277+
if self.debug {
278+
eprintln!(
279+
"Text {:?}",
280+
block
281+
.description
282+
.clone()
283+
.unwrap_or("<no description>".to_string())
284+
);
285+
}
286+
270287
let mut input_messages = context.clone();
271288
let mut output_messages = vec![];
272289
let mut output_blocks = vec![];
273290

274-
self.scope.push(new_scope);
291+
self.extend_scope_with_map(&block.defs);
275292
block.text.iter().try_for_each(|block| {
276293
// run each element of the Text block
277294
let (this_messages, trace) = self.run(&block, input_messages.clone())?;

pdl-live-react/src-tauri/src/pdl/interpreter_tests.rs

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,19 @@ mod tests {
7474
assert_eq!(messages[2].role, MessageRole::User);
7575
assert_eq!(messages[2].content, "in europe?");
7676
assert_eq!(messages[3].role, MessageRole::Assistant);
77+
78+
let m = messages[3].content.to_lowercase();
7779
assert!(
78-
messages[3].content.contains("peregrine")
79-
|| messages[3].content.contains("bison")
80-
|| messages[3].content.contains("hare")
81-
|| messages[3].content.contains("Eagle")
82-
|| messages[3].content.contains("Greyhound")
83-
|| messages[3].content.contains("gazelle")
84-
|| messages[3].content.contains("lynx")
80+
m.contains("peregrine")
81+
|| m.contains("bison")
82+
|| m.contains("hare")
83+
|| m.contains("golden eagle")
84+
|| m.contains("greyhound")
85+
|| m.contains("gazelle")
86+
|| m.contains("lynx")
87+
|| m.contains("boar")
88+
|| m.contains("sailfish")
89+
|| m.contains("pronghorn")
8590
);
8691
Ok(())
8792
}
@@ -130,4 +135,32 @@ mod tests {
130135
assert_eq!(messages[0].content, "hello world");
131136
Ok(())
132137
}
138+
139+
#[test]
140+
fn text_call_function_with_args() -> Result<(), Box<dyn Error>> {
141+
let program = json!({
142+
"defs": {
143+
"foo": {
144+
"function": {
145+
"x": "int"
146+
},
147+
"return": {
148+
"description": "nullary function",
149+
"text": [
150+
"hello world ${x+1}"
151+
]
152+
}
153+
}
154+
},
155+
"text": [
156+
{ "call": "${ foo }", "args": { "x": 3 } },
157+
]
158+
});
159+
160+
let (messages, _) = run_json(program, false)?;
161+
assert_eq!(messages.len(), 1);
162+
assert_eq!(messages[0].role, MessageRole::User);
163+
assert_eq!(messages[0].content, "hello world 4");
164+
Ok(())
165+
}
133166
}

0 commit comments

Comments
 (0)