Skip to content

Commit ebd6511

Browse files
committed
eval improvements, text block def, test for this
Signed-off-by: Nick Mitchell <[email protected]>
1 parent e5b56b5 commit ebd6511

File tree

7 files changed

+159
-42
lines changed

7 files changed

+159
-42
lines changed
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
name: Rust Interpreter Tests
2+
3+
on:
4+
push:
5+
branches: [ main ]
6+
pull_request:
7+
branches: [ main ]
8+
9+
# cancel any prior runs for this workflow and this PR (or branch)
10+
concurrency:
11+
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
12+
cancel-in-progress: true
13+
14+
jobs:
15+
rust-interpreter:
16+
name: Test Rust interpreter
17+
runs-on: ubuntu-latest
18+
defaults:
19+
run:
20+
working-directory: ./pdl-live-react
21+
steps:
22+
- uses: actions/checkout@v4
23+
- name: Set up node
24+
uses: actions/setup-node@v4
25+
with:
26+
node-version: 22
27+
- name: Install dependencies
28+
# sleep 2 to wait for ollama to be running... hack warning
29+
run: |
30+
npm ci & sudo apt update && sudo apt install -y libgtk-3-dev libwebkit2gtk-4.1-dev librsvg2-dev patchelf at-spi2-core &
31+
(curl -fsSL https://ollama.com/install.sh | sudo -E sh && sleep 2)
32+
wait
33+
# todo: do this in rust
34+
ollama pull granite3.2:2b
35+
- name: Run interpreter tests
36+
run: npm run test:interpreter

.github/workflows/tauri-cli.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ jobs:
4545
for i in ./demos/beeai/*.py
4646
do pdl compile beeai $i -g -o /tmp/z.json && jq .description /tmp/z.json
4747
done
48-
48+
4949
- name: Test pdl run against production build
5050
env:
5151
DISPLAY: :1

pdl-live-react/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
"test:bee": "until [ -f ./src-tauri/target/debug/pdl ]; do sleep 1; done; for i in ./demos/beeai/*.py; do ./src-tauri/target/debug/pdl compile beeai $i -g --output - | jq; done",
1818
"test:interpreter": "cd src-tauri && cargo test",
1919
"types": "(cd .. && python -m src.pdl.pdl --schema > src/pdl/pdl-schema.json) && json2ts ../src/pdl/pdl-schema.json src/pdl_ast.d.ts --unreachableDefinitions && npm run format",
20-
"test": "concurrently -n 'quality,playwright,interpreter' 'npm run test:quality' 'npm run test:ui' 'npm run test:interpreter'",
20+
"test": "concurrently -n 'quality,playwright' 'npm run test:quality' 'npm run test:ui'",
2121
"pdl": "./src-tauri/target/debug/pdl",
2222
"view": "npm run pdl view",
2323
"start": "npm run tauri dev"

pdl-live-react/src-tauri/src/compile/beeai.rs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ fn with_tools(
167167

168168
fn call_tools(model: &String, parameters: &HashMap<String, Value>) -> PdlBlock {
169169
let repeat = PdlBlock::Text(PdlTextBlock {
170+
def: None,
170171
defs: None,
171172
role: None,
172173
parser: None,
@@ -437,6 +438,7 @@ asyncio.run(invoke())
437438
model_call.push(PdlBlock::Text(PdlTextBlock {
438439
role: Some(Role::System),
439440
text: vec![PdlBlock::String(instructions)],
441+
def: None,
440442
defs: None,
441443
parser: None,
442444
description: Some("Model instructions".into()),
@@ -476,6 +478,7 @@ asyncio.run(invoke())
476478
PdlBlock::Function(PdlFunctionBlock {
477479
function: HashMap::new(),
478480
return_: Box::new(PdlBlock::Text(PdlTextBlock {
481+
def: None,
479482
defs: None,
480483
role: None,
481484
parser: None,
@@ -485,15 +488,15 @@ asyncio.run(invoke())
485488
}),
486489
);
487490
PdlBlock::Text(PdlTextBlock {
491+
def: None,
488492
defs: Some(defs),
489493
role: None,
490494
parser: None,
491495
description: Some("Model call wrapper".to_string()),
492-
text: vec![PdlBlock::Call(PdlCallBlock {
493-
call: format!("${{ {} }}", closure_name),
494-
defs: None,
495-
args: None,
496-
})],
496+
text: vec![PdlBlock::Call(PdlCallBlock::new(format!(
497+
"${{ {} }}",
498+
closure_name
499+
)))],
497500
})
498501
},
499502
)
@@ -504,6 +507,7 @@ asyncio.run(invoke())
504507
.collect::<Vec<_>>();
505508

506509
let pdl: PdlBlock = PdlBlock::Text(PdlTextBlock {
510+
def: None,
507511
defs: if tool_declarations.len() == 0 {
508512
None
509513
} else {

pdl-live-react/src-tauri/src/pdl/ast.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,16 @@ pub struct PdlCallBlock {
5555
pub defs: Option<HashMap<String, PdlBlock>>,
5656
}
5757

58+
impl PdlCallBlock {
59+
pub fn new(call: String) -> Self {
60+
PdlCallBlock {
61+
call: call,
62+
args: None,
63+
defs: None,
64+
}
65+
}
66+
}
67+
5868
#[derive(Serialize, Deserialize, Debug, Clone)]
5969
pub struct PdlTextBlock {
6070
#[serde(skip_serializing_if = "Option::is_none")]
@@ -66,11 +76,14 @@ pub struct PdlTextBlock {
6676
pub defs: Option<HashMap<String, PdlBlock>>,
6777
#[serde(skip_serializing_if = "Option::is_none")]
6878
pub parser: Option<PdlParser>,
79+
#[serde(skip_serializing_if = "Option::is_none")]
80+
pub def: Option<String>,
6981
}
7082

7183
impl PdlTextBlock {
7284
pub fn new(text: Vec<PdlBlock>) -> Self {
7385
PdlTextBlock {
86+
def: None,
7487
defs: None,
7588
description: None,
7689
role: None,
@@ -79,6 +92,11 @@ impl PdlTextBlock {
7992
}
8093
}
8194

95+
pub fn def(&mut self, def: &str) -> &mut Self {
96+
self.def = Some(def.into());
97+
self
98+
}
99+
82100
pub fn description(&mut self, description: String) -> &mut Self {
83101
self.description = Some(description);
84102
self

pdl-live-react/src-tauri/src/pdl/interpreter.rs

Lines changed: 67 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ use ::std::error::Error;
55
use ::std::fs::File;
66
// use ::std::path::PathBuf;
77

8-
use minijinja::value::ValueKind;
98
use minijinja::Environment;
9+
1010
use ollama_rs::{
1111
generation::{
1212
chat::{request::ChatMessageRequest, ChatMessage, MessageRole},
@@ -15,37 +15,38 @@ use ollama_rs::{
1515
models::ModelOptions,
1616
Ollama,
1717
};
18+
1819
use serde_json::{from_str, to_string, Value};
1920
use serde_norway::from_reader;
2021
use tokio::runtime::Runtime;
2122

2223
use crate::pdl::ast::{
23-
PdlBlock, PdlCallBlock, PdlModelBlock, PdlRepeatBlock, PdlTextBlock, PdlUsage, Role,
24+
PdlBlock, PdlCallBlock, PdlModelBlock, PdlParser, PdlRepeatBlock, PdlTextBlock, PdlUsage, Role,
2425
};
2526

2627
type Context = Vec<ChatMessage>;
27-
type Scope = HashMap<String, String>;
28+
type Scope = HashMap<String, Value>;
2829
type Interpretation = Result<(Context, PdlBlock), Box<dyn Error>>;
2930

30-
struct Interpreter<'a> {
31+
struct Interpreter {
3132
// batch: u32,
3233
// role: Role,
3334
// cwd: Box<PathBuf>,
3435
// id_stack: Vec<String>,
35-
jinja_env: Environment<'a>,
36+
// jinja_env: Environment<'a>,
3637
rt: Runtime,
3738
scope: Vec<Scope>,
3839
debug: bool,
3940
}
4041

41-
impl<'a> Interpreter<'a> {
42+
impl Interpreter {
4243
fn new() -> Self {
4344
Self {
4445
// batch: 0,
4546
// role: Role::User,
4647
// cwd: Box::new(current_dir().unwrap_or(PathBuf::from("/"))),
4748
// id_stack: vec![],
48-
jinja_env: Environment::new(),
49+
// jinja_env: Environment::new(),
4950
rt: Runtime::new().unwrap(),
5051
scope: vec![Scope::new()],
5152
debug: false,
@@ -66,26 +67,41 @@ impl<'a> Interpreter<'a> {
6667
Ok((messages, trace))
6768
}
6869

69-
fn eval(&self, expr: &String) -> Result<PdlBlock, Box<dyn Error>> {
70-
let e = expr.replace("${ ", "").replace(" }", "");
71-
let jexpr = self.jinja_env.compile_expression(&e)?;
72-
let result = jexpr.eval(self.scope.last().unwrap_or(&HashMap::new()))?;
70+
fn eval(&self, pdl_expr: &String) -> Result<PdlBlock, Box<dyn Error>> {
71+
let expr = pdl_expr.replace("${ ", "{{").replace(" }", "}}"); // FIXME regexp
72+
let mut env = Environment::new();
73+
env.add_template(pdl_expr, expr.as_str())?;
74+
let tmpl = env.get_template(pdl_expr)?;
75+
let result = tmpl.render(self.scope.last().unwrap_or(&HashMap::new()))?;
76+
eprintln!("Eval '{}' -> {}", &expr, &result);
77+
78+
match from_str(&result) {
79+
Err(_) => {
80+
eprintln!("Plain string {}", &result);
81+
Ok(PdlBlock::String(result))
82+
}
83+
Ok(x) => Ok(x),
84+
}
85+
//let jexpr = self.jinja_env.compile_expression(&e)?;
86+
//let result = jexpr.eval(self.scope.last().unwrap_or(&HashMap::new()))?;
7387
//dbg!("Eval '{}' -> {:?}", e, result);
7488

75-
match result.kind() {
76-
ValueKind::String => Ok(from_str(&result.as_str().unwrap())?),
77-
t => Err(Box::from(format!("Unexpected jinja result type {}", t))),
78-
}
89+
//match result.kind() {
90+
// ValueKind::String => Ok(from_str(&result.as_str().unwrap())?),
91+
// t => Err(Box::from(format!("Unexpected jinja result type: {} -> {}. {:?}", expr, t, result))),
92+
// }
7993
}
8094

8195
fn run_string(&self, msg: &String, _context: Context) -> Interpretation {
96+
let trace = self.eval(msg)?;
8297
if self.debug {
83-
eprintln!("String {}", msg);
98+
eprintln!("String {} -> {:?}", msg, trace);
8499
}
85-
//TODO? self.eval(msg)
86100

87-
let messages = vec![ChatMessage::user(msg.clone())];
88-
let trace = PdlBlock::String(msg.to_string());
101+
let messages = vec![ChatMessage::user(match &trace {
102+
PdlBlock::String(s) => s.clone(),
103+
x => to_string(&x)?,
104+
})];
89105

90106
Ok((messages, trace))
91107
}
@@ -222,6 +238,12 @@ impl<'a> Interpreter<'a> {
222238
}
223239
}
224240

241+
fn parse_result(&self, parser: &PdlParser, result: &String) -> Result<Value, Box<dyn Error>> {
242+
match parser {
243+
PdlParser::Json => Ok(from_str(result)?),
244+
}
245+
}
246+
225247
fn run_text(&mut self, block: &PdlTextBlock, context: Context) -> Interpretation {
226248
if self.debug {
227249
eprintln!(
@@ -237,10 +259,12 @@ impl<'a> Interpreter<'a> {
237259
Some(defs) => {
238260
// this is all non-optimal
239261
let mut scope: Scope = HashMap::from(cur_scope);
240-
scope.extend(
241-
defs.iter()
242-
.map(|(var, def)| (var.clone(), to_string(def).unwrap())),
243-
);
262+
scope.extend(defs.iter().map(|(var, def)| {
263+
(
264+
var.clone(),
265+
from_str(to_string(def).unwrap().as_str()).unwrap(),
266+
)
267+
}));
244268
scope
245269
}
246270
None => cur_scope,
@@ -264,16 +288,28 @@ impl<'a> Interpreter<'a> {
264288
let mut trace = block.clone();
265289
trace.text = output_blocks;
266290

291+
let result_string = output_messages
292+
.iter()
293+
.map(|m| m.content.clone())
294+
.collect::<Vec<_>>()
295+
.join("\n");
296+
297+
if let Some(def) = &block.def {
298+
let result = if let Some(parser) = &block.parser {
299+
self.parse_result(parser, &result_string)?
300+
} else {
301+
Value::from(result_string.clone()) // TODO
302+
};
303+
304+
if let Some(scope) = self.scope.last_mut() {
305+
eprintln!("DEF {} -> {}", def, result);
306+
scope.insert(def.clone(), result);
307+
}
308+
}
309+
267310
Ok((
268311
match &block.role {
269-
Some(role) => vec![ChatMessage::new(
270-
self.to_ollama_role(role),
271-
output_messages
272-
.into_iter()
273-
.map(|m| m.content)
274-
.collect::<Vec<_>>()
275-
.join("\n"),
276-
)],
312+
Some(role) => vec![ChatMessage::new(self.to_ollama_role(role), result_string)],
277313
None => output_messages,
278314
},
279315
PdlBlock::Text(trace),

pdl-live-react/src-tauri/src/pdl/interpreter_tests.rs

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,7 @@ mod tests {
44
use ::std::error::Error;
55

66
use crate::pdl::{
7-
ast::{
8-
PdlBlock, /*PdlCallBlock,*/
9-
PdlModelBlock, /*PdlRepeatBlock,*/ /*PdlTextBlock,*/ /*, PdlUsage, Role,*/
10-
},
7+
ast::{PdlBlock, PdlModelBlock, PdlParser, PdlTextBlock},
118
interpreter::run,
129
};
1310

@@ -55,4 +52,30 @@ mod tests {
5552
assert!(messages[1].content.contains("Hello!"));
5653
Ok(())
5754
}
55+
56+
#[test]
57+
fn text_parser_json() -> Result<(), Box<dyn Error>> {
58+
let json = "{\"key\": \"value\"}";
59+
let program = PdlBlock::Text(
60+
vec![
61+
PdlBlock::Text(
62+
PdlTextBlock::new(vec![json.into()])
63+
.def(&"foo")
64+
.parser(PdlParser::Json)
65+
.build(),
66+
),
67+
"${ foo.key }".into(),
68+
]
69+
.into(),
70+
);
71+
println!("{}", serde_json::to_string(&program)?);
72+
73+
let (messages, _) = run(&program, false)?;
74+
assert_eq!(messages.len(), 2);
75+
assert_eq!(messages[0].role, MessageRole::User);
76+
assert_eq!(messages[0].content, json);
77+
assert_eq!(messages[1].role, MessageRole::User);
78+
assert_eq!(messages[1].content, "value");
79+
Ok(())
80+
}
5881
}

0 commit comments

Comments
 (0)