Skip to content

Commit 8353572

Browse files
committed
datablock support
Signed-off-by: Nick Mitchell <[email protected]>
1 parent fb20c88 commit 8353572

File tree

6 files changed

+170
-21
lines changed

6 files changed

+170
-21
lines changed

pdl-live-react/src-tauri/src/pdl/ast.rs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,11 +381,51 @@ pub struct MessageBlock {
381381
pub tool_call_id: Option<String>,
382382
}
383383

384+
/// Return the object where the value of each field is defined by a
385+
/// block. If the body of the object is an array, the resulting object
386+
/// is the union of the objects computed by each element of the array.
384387
#[derive(Serialize, Deserialize, Debug, Clone)]
385388
pub struct ObjectBlock {
386389
pub object: HashMap<String, PdlBlock>,
387390
}
388391

392+
/// Arbitrary value, equivalent to JSON.
393+
///
394+
/// Example. As part of a `defs` section, set `numbers` to the list `[1, 2, 3, 4]`:
395+
/// ```PDL
396+
/// defs:
397+
/// numbers:
398+
/// data: [1, 2, 3, 4]
399+
/// ```
400+
///
401+
/// Example. Evaluate `${ TEST.answer }` in
402+
/// [Jinja](https://jinja.palletsprojects.com/en/stable/), passing
403+
/// the result to a regex parser with capture groups. Set
404+
/// `EXTRACTED_GROUND_TRUTH` to an object with attribute `answer`,
405+
/// a string, containing the value of the capture group.
406+
/// ```PDL
407+
/// - data: ${ TEST.answer }
408+
/// parser:
409+
/// regex: "(.|\\n)*#### (?P<answer>([0-9])*)\\n*"
410+
/// spec:
411+
/// answer: str
412+
/// def: EXTRACTED_GROUND_TRUTH
413+
/// ```
414+
#[derive(Serialize, Deserialize, Debug, Clone)]
415+
pub struct DataBlock {
416+
pub data: Value,
417+
418+
/// Do not evaluate expressions inside strings.
419+
#[serde(skip_serializing_if = "Option::is_none")]
420+
pub raw: Option<bool>,
421+
422+
#[serde(skip_serializing_if = "Option::is_none")]
423+
pub def: Option<String>,
424+
425+
#[serde(skip_serializing_if = "Option::is_none")]
426+
pub parser: Option<PdlParser>,
427+
}
428+
389429
/// Execute a piece of Python code.
390430
///
391431
/// Example:
@@ -491,6 +531,7 @@ pub enum PdlBlock {
491531
String(String),
492532
If(IfBlock),
493533
Include(IncludeBlock),
534+
Data(DataBlock),
494535
Object(ObjectBlock),
495536
Call(CallBlock),
496537
Array(ArrayBlock),
@@ -557,6 +598,11 @@ impl From<String> for PdlResult {
557598
PdlResult::String(s)
558599
}
559600
}
601+
impl From<&bool> for PdlResult {
602+
fn from(b: &bool) -> Self {
603+
PdlResult::Bool(*b)
604+
}
605+
}
560606
impl From<Number> for PdlResult {
561607
fn from(n: Number) -> Self {
562608
PdlResult::Number(n)

pdl-live-react/src-tauri/src/pdl/interpreter.rs

Lines changed: 50 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use serde_json::{from_str, to_string, Value};
2525
use serde_norway::{from_reader, from_str as from_yaml_str};
2626

2727
use crate::pdl::ast::{
28-
ArrayBlock, CallBlock, Closure, FunctionBlock, IfBlock, IncludeBlock, ListOrString,
28+
ArrayBlock, CallBlock, Closure, DataBlock, FunctionBlock, IfBlock, IncludeBlock, ListOrString,
2929
MessageBlock, ModelBlock, ObjectBlock, PdlBlock, PdlParser, PdlResult, PdlUsage,
3030
PythonCodeBlock, ReadBlock, RepeatBlock, Role, Scope, SequencingBlock, StringOrBoolean,
3131
};
@@ -102,6 +102,7 @@ impl<'a> Interpreter<'a> {
102102
PdlBlock::If(block) => self.run_if(block, context).await,
103103
PdlBlock::Include(block) => self.run_include(block, context).await,
104104
PdlBlock::Model(block) => self.run_model(block, context).await,
105+
PdlBlock::Data(block) => self.run_data(block, context).await,
105106
PdlBlock::Object(block) => self.run_object(block, context).await,
106107
PdlBlock::PythonCode(block) => self.run_python_code(block, context).await,
107108
PdlBlock::Read(block) => self.run_read(block, context).await,
@@ -183,7 +184,6 @@ impl<'a> Interpreter<'a> {
183184
})
184185
.collect::<Result<_, _>>()?,
185186
)),
186-
// v => Ok(v.clone()),
187187
}
188188
}
189189

@@ -575,6 +575,39 @@ impl<'a> Interpreter<'a> {
575575
}
576576
}
577577

578+
fn resultify(&self, value: &Value) -> PdlResult {
579+
match value {
580+
Value::Null => "".into(),
581+
Value::Bool(b) => b.into(),
582+
Value::Number(n) => n.clone().into(),
583+
Value::String(s) => s.clone().into(),
584+
Value::Array(a) => {
585+
PdlResult::List(a.iter().map(|v| self.resultify(v)).collect::<Vec<_>>())
586+
}
587+
Value::Object(m) => PdlResult::Dict(
588+
m.iter()
589+
.map(|(k, v)| (k.clone(), self.resultify(v)))
590+
.collect::<HashMap<_, _>>(),
591+
),
592+
}
593+
}
594+
595+
async fn run_data(&mut self, block: &DataBlock, _context: Context) -> Interpretation {
596+
if self.debug {
597+
eprintln!("Data raw={:?} {:?}", block.raw, block.data);
598+
}
599+
600+
let mut trace = block.clone();
601+
if let Some(true) = block.raw {
602+
let result = self.def(&block.def, &self.resultify(&block.data), &block.parser)?;
603+
Ok((result, vec![], PdlBlock::Data(trace)))
604+
} else {
605+
let result = self.def(&block.def, &self.eval_complex(&block.data)?, &block.parser)?;
606+
trace.data = from_str(to_string(&result)?.as_str())?;
607+
Ok((result, vec![], PdlBlock::Data(trace)))
608+
}
609+
}
610+
578611
async fn run_object(&mut self, block: &ObjectBlock, context: Context) -> Interpretation {
579612
if self.debug {
580613
eprintln!("Object {:?}", block.object);
@@ -583,6 +616,7 @@ impl<'a> Interpreter<'a> {
583616
let mut messages = vec![];
584617
let mut result_map = HashMap::new();
585618
let mut trace_map = HashMap::new();
619+
586620
let mut iter = block.object.iter();
587621
while let Some((k, v)) = iter.next() {
588622
let (this_result, this_messages, this_trace) =
@@ -679,28 +713,22 @@ impl<'a> Interpreter<'a> {
679713

680714
async fn process_defs(
681715
&mut self,
682-
map: &Option<HashMap<String, PdlBlock>>,
716+
defs: &Option<HashMap<String, PdlBlock>>,
683717
) -> Result<(), PdlError> {
684-
let cur_scope = self.scope.last().unwrap_or(&HashMap::new()).clone();
685-
let new_scope = match map {
686-
Some(defs) => {
687-
// this is all non-optimal
688-
let mut scope: Scope = HashMap::from(cur_scope);
689-
let mut iter = defs.iter();
690-
while let Some((var, def)) = iter.next() {
691-
let (result, _, _) = self.run_quiet(def, vec![]).await?;
692-
scope.insert(
693-
var.clone(),
694-
result,
695-
//from_str(to_string(&block).unwrap().as_str()).unwrap(),
696-
);
697-
}
698-
scope
718+
let mut new_scope: Scope = HashMap::new();
719+
if let Some(cur_scope) = self.scope.last() {
720+
new_scope.extend(cur_scope.clone());
721+
}
722+
self.scope.push(new_scope);
723+
724+
if let Some(defs) = defs {
725+
let mut iter = defs.iter();
726+
while let Some((var, def)) = iter.next() {
727+
let (result, _, _) = self.run_quiet(def, vec![]).await?;
728+
let _ = self.def(&Some(var.clone()), &result, &None);
699729
}
700-
None => cur_scope,
701-
};
730+
}
702731

703-
self.scope.push(new_scope);
704732
Ok(())
705733
}
706734

@@ -737,6 +765,7 @@ impl<'a> Interpreter<'a> {
737765
output_messages.extend(this_messages);
738766
output_blocks.push(trace);
739767
}
768+
740769
self.scope.pop();
741770

742771
let trace = block.with_items(output_blocks);

pdl-live-react/src-tauri/src/pdl/interpreter_tests.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,4 +469,43 @@ mod tests {
469469
assert_eq!(messages[0].content, "hello world 4 bye");
470470
Ok(())
471471
}
472+
473+
#[test]
474+
fn text_data_1() -> Result<(), Box<dyn Error>> {
475+
let program = json!({
476+
"include": "./tests/cli/data1.pdl"
477+
});
478+
479+
let (_, messages, _) = run_json(program, false)?;
480+
assert_eq!(messages.len(), 1);
481+
assert_eq!(messages[0].role, MessageRole::User);
482+
assert_eq!(messages[0].content, "xxxx3true");
483+
Ok(())
484+
}
485+
486+
#[test]
487+
fn text_data_2() -> Result<(), Box<dyn Error>> {
488+
let program = json!({
489+
"include": "./tests/cli/data2.pdl"
490+
});
491+
492+
let (_, messages, _) = run_json(program, false)?;
493+
assert_eq!(messages.len(), 1);
494+
assert_eq!(messages[0].role, MessageRole::User);
495+
assert_eq!(messages[0].content, "xxxx3true");
496+
Ok(())
497+
}
498+
499+
#[test]
500+
fn text_data_3() -> Result<(), Box<dyn Error>> {
501+
let program = json!({
502+
"include": "./tests/cli/data3.pdl"
503+
});
504+
505+
let (_, messages, _) = run_json(program, false)?;
506+
assert_eq!(messages.len(), 1);
507+
assert_eq!(messages[0].role, MessageRole::User);
508+
assert_eq!(messages[0].content, "${x}3true");
509+
Ok(())
510+
}
472511
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
lastOf:
2+
- def: x
3+
text:
4+
- xxxx
5+
- def: y
6+
data:
7+
n: 3
8+
x: ${x}
9+
b: true
10+
- ${y.x~y.n~y.b}
11+
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
defs:
2+
x:
3+
text:
4+
- xxxx
5+
y:
6+
data:
7+
n: 3
8+
x: ${x}
9+
b: true
10+
lastOf:
11+
- ${y.x~y.n~y.b}
12+
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
lastOf:
2+
- def: x
3+
text:
4+
- xxxx
5+
- def: y
6+
raw: true
7+
data:
8+
n: 3
9+
x: ${x}
10+
b: true
11+
- ${y.x~y.n~y.b}
12+

0 commit comments

Comments
 (0)