Skip to content

Commit b0cd0d4

Browse files
committed
repeat blocks
Signed-off-by: Nick Mitchell <[email protected]>
1 parent 18dbb0f commit b0cd0d4

File tree

13 files changed

+241
-63
lines changed

13 files changed

+241
-63
lines changed
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# emacs temp files
22
*~
33
\#*\#
4-
.\#*
4+
.\#*
5+
tests/**/*.txt
6+
tests/**/*.pdl

pdl-live-react/src-tauri/src/compile/beeai.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ use serde_json::{from_reader, json, to_string, Map, Value};
1212
use tempfile::Builder;
1313

1414
use crate::pdl::ast::{
15-
PdlBaseType, PdlBlock, PdlCallBlock, PdlFunctionBlock, PdlMessageBlock, PdlModelBlock,
16-
PdlObjectBlock, PdlOptionalType, PdlParser, PdlPythonCodeBlock, PdlRepeatBlock, PdlTextBlock,
17-
PdlType, Role,
15+
PdlBaseType, PdlBlock, PdlCallBlock, PdlFunctionBlock, PdlListOrString, PdlMessageBlock,
16+
PdlModelBlock, PdlObjectBlock, PdlOptionalType, PdlParser, PdlPythonCodeBlock, PdlRepeatBlock,
17+
PdlTextBlock, PdlType, Role,
1818
};
1919
use crate::pdl::pip::pip_install_if_needed;
2020
use crate::pdl::requirements::BEEAI_FRAMEWORK;
@@ -58,7 +58,7 @@ struct BeeAiToolState {
5858
name: String,
5959
description: Option<String>,
6060
input_schema: BeeAiToolSchema,
61-
options: Option<HashMap<String, Value>>,
61+
// options: Option<HashMap<String, Value>>,
6262
}
6363
#[derive(Deserialize, Debug)]
6464
struct BeeAiTool {
@@ -223,7 +223,7 @@ fn call_tools(model: &String, parameters: &HashMap<String, Value>) -> PdlBlock {
223223
let mut for_ = HashMap::new();
224224
for_.insert(
225225
"tool".to_string(),
226-
"${ response.choices[0].message.tool_calls }".to_string(),
226+
PdlListOrString::String("${ response.choices[0].message.tool_calls }".to_string()),
227227
);
228228

229229
// response.choices[0].message.tool_calls

pdl-live-react/src-tauri/src/pdl/ast.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,10 +197,17 @@ impl PdlModelBlock {
197197
}
198198
}
199199

200+
#[derive(Serialize, Deserialize, Debug, Clone)]
201+
#[serde(untagged)]
202+
pub enum PdlListOrString {
203+
String(String),
204+
List(Vec<Value>),
205+
}
206+
200207
#[derive(Serialize, Deserialize, Debug, Clone)]
201208
pub struct PdlRepeatBlock {
202209
#[serde(rename = "for")]
203-
pub for_: HashMap<String, String>,
210+
pub for_: HashMap<String, PdlListOrString>,
204211
pub repeat: Box<PdlBlock>,
205212
}
206213

pdl-live-react/src-tauri/src/pdl/interpreter.rs

Lines changed: 97 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ use serde_json::{from_str, to_string, Map, Value};
2525
use serde_norway::{from_reader, from_str as from_yaml_str};
2626

2727
use crate::pdl::ast::{
28-
PdlBlock, PdlCallBlock, PdlModelBlock, PdlParser, PdlPythonCodeBlock, PdlReadBlock,
29-
PdlRepeatBlock, PdlTextBlock, PdlUsage, Role,
28+
PdlBlock, PdlCallBlock, PdlListOrString, PdlModelBlock, PdlParser, PdlPythonCodeBlock,
29+
PdlReadBlock, PdlRepeatBlock, PdlTextBlock, PdlUsage, Role,
3030
};
3131

3232
type Context = Vec<ChatMessage>;
@@ -112,7 +112,7 @@ impl<'a> Interpreter<'a> {
112112
self.run_with_emit(program, context, self.emit).await
113113
}
114114

115-
// Evaluate as a Jinja2 expression
115+
// Evaluate String as a Jinja2 expression
116116
fn eval<T: serde::de::DeserializeOwned + ::std::convert::From<String>>(
117117
&self,
118118
expr: &String,
@@ -134,6 +134,43 @@ impl<'a> Interpreter<'a> {
134134
}))
135135
}
136136

137+
fn eval_complex(&self, expr: &Value) -> Result<Value, Box<dyn Error + Send + Sync>> {
138+
match expr {
139+
Value::String(s) => self.eval(s),
140+
Value::Array(a) => Ok(Value::Array(
141+
a.iter()
142+
.map(|v| self.eval_complex(v))
143+
.collect::<Result<_, _>>()?,
144+
)),
145+
Value::Object(o) => Ok(Value::Object(
146+
o.iter()
147+
.map(|(k, v)| match self.eval_complex(v) {
148+
Ok(v) => Ok((k.clone(), v)),
149+
Err(e) => Err(e),
150+
})
151+
.collect::<Result<_, _>>()?,
152+
)),
153+
v => Ok(v.clone()),
154+
}
155+
}
156+
157+
// Evaluate an string or list of Values into a list of Values
158+
fn eval_list_or_string(
159+
&self,
160+
expr: &PdlListOrString,
161+
) -> Result<Vec<Value>, Box<dyn Error + Send + Sync>> {
162+
match expr {
163+
PdlListOrString::String(s) => match self.eval::<Value>(s)? {
164+
Value::Array(a) => Ok(a),
165+
x => Err(Box::from(format!(
166+
"Jinja string expanded to non-list. {} -> {:?}",
167+
s, x
168+
))),
169+
},
170+
PdlListOrString::List(l) => l.iter().map(|v| self.eval_complex(v)).collect(),
171+
}
172+
}
173+
137174
// Run a PdlBlock::String
138175
async fn run_string(&self, msg: &String, _context: Context) -> Interpretation {
139176
let trace = self.eval::<PdlBlock>(msg)?;
@@ -179,53 +216,26 @@ impl<'a> Interpreter<'a> {
179216
eprintln!("Call {:?}({:?})", block.call, block.args);
180217
}
181218

182-
let args = match &block.args {
183-
Some(x) => match x {
184-
// args is a string; eval it and see if we get an Object out the other side
185-
Value::String(s) => match self.eval::<Value>(&s)? {
186-
// args was a string that eval'd to an Object
187-
Value::Object(m) => Ok(Some(self.to_pdl(&m))),
188-
// args was a string that eval'd to something we don't understand
189-
y => Err(Box::<dyn Error + Send + Sync>::from(format!(
190-
"Invalid arguments to call {:?}",
191-
y
192-
))),
193-
},
194-
// args is already an Object
195-
Value::Object(m) => Ok(Some(self.to_pdl(&m))),
196-
// args is something we don't understand
197-
y => Err(Box::from(format!("Invalid arguments to call {:?}", y))),
198-
},
199-
// no args... that's ok (TODO: check against function schema)
200-
None => Ok(None),
201-
}?;
202-
self.extend_scope_with_map(&args);
219+
if let Some(args) = &block.args {
220+
match self.eval_complex(args)? {
221+
Value::Object(m) => Ok(self.extend_scope_with_json_map(m)),
222+
x => Err(Box::<dyn Error + Send + Sync>::from(format!(
223+
"Call arguments not a map: {:?}",
224+
x
225+
))),
226+
}?;
227+
}
203228

204229
let res = match self.eval::<PdlBlock>(&block.call)? {
205230
PdlBlock::Function(f) => self.run(&f.return_, context.clone()).await,
206231
_ => Err(Box::from(format!("call of non-function {:?}", &block.call))),
207232
};
208-
self.scope.pop();
209233

210-
res
211-
}
234+
if let Some(_) = block.args {
235+
self.scope.pop();
236+
}
212237

213-
fn to_pdl(&self, m: &Map<String, Value>) -> HashMap<String, PdlBlock> {
214-
m.into_iter()
215-
.map(|(k, v)| {
216-
(
217-
k.clone(),
218-
match v {
219-
Value::String(s) => PdlBlock::String(s.clone()),
220-
Value::Number(n) => PdlBlock::Number(n.clone()),
221-
x => {
222-
eprintln!("Unhandled arg value {:?}", x);
223-
"error".into()
224-
}
225-
},
226-
)
227-
})
228-
.collect()
238+
res
229239
}
230240

231241
fn to_ollama_model_options(
@@ -428,19 +438,41 @@ impl<'a> Interpreter<'a> {
428438
}
429439

430440
// Run a PdlBlock::Repeat
431-
async fn run_repeat(&mut self, block: &PdlRepeatBlock, _context: Context) -> Interpretation {
432-
let for_ = block
441+
async fn run_repeat(&mut self, block: &PdlRepeatBlock, context: Context) -> Interpretation {
442+
// { i:[1,2,3], j: [4,5,6]} -> ([i,j], [[1,2,3],[4,5,6]])
443+
// let (variables, values): (Vec<_>, Vec<Vec<_>>) = block
444+
// .into_iter()
445+
// .unzip();
446+
let map = block
433447
.for_
434448
.iter()
435-
.map(|(var, values)| (var, self.eval::<PdlBlock>(&values)));
449+
.map(|(var, values)| match self.eval_list_or_string(values) {
450+
Ok(value) => Ok((var.clone(), value)),
451+
Err(e) => Err(e),
452+
})
453+
.collect::<Result<HashMap<_, _>, _>>()?;
436454

437455
if self.debug {
438-
eprintln!("Repeat {:?}", &for_);
456+
eprintln!("Repeat {:?}", map);
439457
}
440-
Ok((
441-
vec![ChatMessage::user("TODO".into())],
442-
PdlBlock::Repeat(block.clone()),
443-
))
458+
459+
let mut messages = vec![];
460+
let mut trace = vec![];
461+
if let Some(n) = map.iter().map(|(_, v)| v.len()).min() {
462+
for iter in 0..n {
463+
let scope: HashMap<String, Value> = map
464+
.iter()
465+
.map(|(k, v)| (k.clone(), v[iter].clone()))
466+
.collect();
467+
self.extend_scope_with_map(scope);
468+
let (ms, t) = self.run_quiet(&block.repeat, context.clone()).await?;
469+
messages.extend(ms);
470+
trace.push(t);
471+
self.scope.pop();
472+
}
473+
}
474+
475+
Ok((messages, PdlBlock::Repeat(block.clone())))
444476
}
445477

446478
fn to_ollama_role(&self, role: &Role) -> MessageRole {
@@ -462,7 +494,18 @@ impl<'a> Interpreter<'a> {
462494
}
463495
}
464496

465-
fn extend_scope_with_map(&mut self, map: &Option<HashMap<String, PdlBlock>>) {
497+
fn extend_scope_with_map(&mut self, new_scope: HashMap<String, Value>) {
498+
self.scope.push(new_scope);
499+
}
500+
501+
fn extend_scope_with_json_map(&mut self, new_scope: Map<String, Value>) {
502+
let mut scope = self.scope.last().unwrap_or(&HashMap::new()).clone();
503+
// TODO figure out iterators
504+
scope.extend(new_scope.into_iter().collect::<HashMap<String, Value>>());
505+
self.extend_scope_with_map(scope);
506+
}
507+
508+
fn extend_scope_with_block_map(&mut self, map: &Option<HashMap<String, PdlBlock>>) {
466509
let cur_scope = self.scope.last().unwrap_or(&HashMap::new()).clone();
467510
let new_scope = match map {
468511
Some(defs) => {
@@ -479,7 +522,7 @@ impl<'a> Interpreter<'a> {
479522
None => cur_scope,
480523
};
481524

482-
self.scope.push(new_scope);
525+
self.extend_scope_with_map(new_scope);
483526
}
484527

485528
// Run a PdlBlock::Text
@@ -498,7 +541,7 @@ impl<'a> Interpreter<'a> {
498541
let mut output_messages = vec![];
499542
let mut output_blocks = vec![];
500543

501-
self.extend_scope_with_map(&block.defs);
544+
self.extend_scope_with_block_map(&block.defs);
502545
let mut iter = block.text.iter();
503546
while let Some(block) = iter.next() {
504547
// run each element of the Text block

pdl-live-react/src-tauri/src/pdl/interpreter_tests.rs

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@ mod tests {
55
use serde_json::json;
66

77
use crate::pdl::{
8-
ast::{PdlBlock, PdlModelBlock, PdlParser, PdlTextBlock},
8+
ast::{PdlBlock, PdlModelBlock},
99
interpreter::{run_json_sync as run_json, run_sync as run},
1010
};
1111

12-
use ollama_rs::generation::chat::MessageRole;
12+
use ollama_rs::generation::chat::{ChatMessage, MessageRole};
1313

1414
const DEFAULT_MODEL: &'static str = "ollama/granite3.2:2b";
1515

@@ -220,4 +220,78 @@ mod tests {
220220
assert_eq!(messages[0].content, "this should be foo\n");
221221
Ok(())
222222
}
223+
224+
#[test]
225+
fn text_repeat_numbers_1d() -> Result<(), Box<dyn Error>> {
226+
let program = json!({
227+
"for": {
228+
"x": [1,2,3]
229+
},
230+
"repeat": {
231+
"text": [
232+
"${ x + 1 }"
233+
]
234+
}
235+
});
236+
237+
let (messages, _) = run_json(program, false)?;
238+
assert_eq!(messages.len(), 3);
239+
assert_eq!(messages[0].role, MessageRole::User);
240+
assert_eq!(messages[0].content, "2");
241+
assert_eq!(messages[1].role, MessageRole::User);
242+
assert_eq!(messages[1].content, "3");
243+
assert_eq!(messages[2].role, MessageRole::User);
244+
assert_eq!(messages[2].content, "4");
245+
Ok(())
246+
}
247+
248+
#[test]
249+
fn text_repeat_numbers_2d() -> Result<(), Box<dyn Error>> {
250+
let program = json!({
251+
"for": {
252+
"x": [1,2,3],
253+
"y": [4,5,6]
254+
},
255+
"repeat": {
256+
"text": [
257+
"${ x + y }"
258+
]
259+
}
260+
});
261+
262+
let (messages, _) = run_json(program, false)?;
263+
assert_eq!(messages.len(), 3);
264+
assert_eq!(messages[0].role, MessageRole::User);
265+
assert_eq!(messages[0].content, "5");
266+
assert_eq!(messages[1].role, MessageRole::User);
267+
assert_eq!(messages[1].content, "7");
268+
assert_eq!(messages[2].role, MessageRole::User);
269+
assert_eq!(messages[2].content, "9");
270+
Ok(())
271+
}
272+
273+
#[test]
274+
fn text_repeat_mix_2d() -> Result<(), Box<dyn Error>> {
275+
let program = json!({
276+
"for": {
277+
"x": [{"z": 4}, {"z": 5}, {"z": 6}],
278+
"y": ["a","b","c"]
279+
},
280+
"repeat": {
281+
"text": [
282+
"${ x.z ~ y }" // ~ is string concatenation in jinja
283+
]
284+
}
285+
});
286+
287+
let (messages, _) = run_json(program, false)?;
288+
assert_eq!(messages.len(), 3);
289+
assert_eq!(messages[0].role, MessageRole::User);
290+
assert_eq!(messages[0].content, "4a");
291+
assert_eq!(messages[1].role, MessageRole::User);
292+
assert_eq!(messages[1].content, "5b");
293+
assert_eq!(messages[2].role, MessageRole::User);
294+
assert_eq!(messages[2].content, "6c");
295+
Ok(())
296+
}
223297
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
defs:
2+
foo:
3+
function: {}
4+
return:
5+
description: nullary function
6+
text:
7+
- hello world
8+
text:
9+
- call: ${ foo }
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
defs:
2+
foo:
3+
function:
4+
x: int
5+
return:
6+
description: nullary function
7+
text:
8+
- hello world ${x+1}
9+
text:
10+
- call: ${ foo }
11+
args:
12+
x: 3
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
lang: python
2+
code: 'print(''hi ho''); result = {"foo": 3}'

0 commit comments

Comments
 (0)