Skip to content

Commit ba5744b

Browse files
committed
messageblock and arrayblock support
Signed-off-by: Nick Mitchell <[email protected]>
1 parent 8de0167 commit ba5744b

File tree

6 files changed

+142
-32
lines changed

6 files changed

+142
-32
lines changed

pdl-live-react/src-tauri/src/compile/beeai.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@ use serde_json::{from_reader, json, to_string, Map, Value};
1212
use tempfile::Builder;
1313

1414
use crate::pdl::ast::{
15-
CallBlock, FunctionBlock, ListOrString, MessageBlock, ModelBlock, ObjectBlock, PdlBaseType,
16-
PdlBlock, PdlOptionalType, PdlParser, PdlType, PythonCodeBlock, RepeatBlock, Role, TextBlock,
15+
ArrayBlock, CallBlock, FunctionBlock, ListOrString, MessageBlock, ModelBlock, ObjectBlock,
16+
PdlBaseType, PdlBlock, PdlOptionalType, PdlParser, PdlType, PythonCodeBlock, RepeatBlock, Role,
17+
TextBlock,
1718
};
1819
use crate::pdl::pip::pip_install_if_needed;
1920
use crate::pdl::requirements::BEEAI_FRAMEWORK;
@@ -198,7 +199,7 @@ fn call_tools(model: &String, parameters: &HashMap<String, Value>) -> PdlBlock {
198199
text: vec![PdlBlock::Model(
199200
ModelBlock::new(model.as_str())
200201
.parameters(&strip_nulls(parameters))
201-
.input(PdlBlock::Array {
202+
.input(PdlBlock::Array(ArrayBlock {
202203
array: vec![PdlBlock::Message(MessageBlock {
203204
role: Role::Tool,
204205
description: None,
@@ -214,7 +215,7 @@ fn call_tools(model: &String, parameters: &HashMap<String, Value>) -> PdlBlock {
214215
args: Some("${ args }".into()), // invoke with arguments as specified by the model
215216
})),
216217
})],
217-
})
218+
}))
218219
.build(),
219220
)],
220221
});

pdl-live-react/src-tauri/src/pdl/ast.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ pub struct ModelBlock {
164164
#[serde(skip_serializing_if = "Option::is_none")]
165165
pub parameters: Option<HashMap<String, Value>>,
166166
#[serde(skip_serializing_if = "Option::is_none")]
167-
pub input: Option<Box<PdlBlock>>, // really this should be restricted to be PdlBlock::Array; how do we do this in rust?
167+
pub input: Option<Box<PdlBlock>>,
168168
#[serde(skip_serializing_if = "Option::is_none")]
169169
#[serde(rename = "modelResponse")]
170170
pub model_response: Option<String>,
@@ -346,6 +346,13 @@ pub struct IfBlock {
346346
pub defs: Option<HashMap<String, PdlBlock>>,
347347
}
348348

349+
/// Return the array of values computed by each block of the list of blocks
350+
#[derive(Serialize, Deserialize, Debug, Clone)]
351+
pub struct ArrayBlock {
352+
/// Elements of the array
353+
pub array: Vec<PdlBlock>,
354+
}
355+
349356
#[derive(Serialize, Deserialize, Debug, Clone)]
350357
#[serde(untagged)]
351358
pub enum PdlBlock {
@@ -355,7 +362,7 @@ pub enum PdlBlock {
355362
If(IfBlock),
356363
Object(ObjectBlock),
357364
Call(CallBlock),
358-
Array { array: Vec<PdlBlock> },
365+
Array(ArrayBlock),
359366
Message(MessageBlock),
360367
Repeat(RepeatBlock),
361368
Text(TextBlock),

pdl-live-react/src-tauri/src/pdl/interpreter.rs

Lines changed: 92 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,9 @@ use serde_json::{from_str, to_string, Number, Value};
2626
use serde_norway::{from_reader, from_str as from_yaml_str};
2727

2828
use crate::pdl::ast::{
29-
CallBlock, FunctionBlock, IfBlock, ListOrString, ModelBlock, ObjectBlock, PdlBlock, PdlParser,
30-
PdlUsage, PythonCodeBlock, ReadBlock, RepeatBlock, Role, StringOrBoolean, TextBlock,
29+
ArrayBlock, CallBlock, FunctionBlock, IfBlock, ListOrString, MessageBlock, ModelBlock,
30+
ObjectBlock, PdlBlock, PdlParser, PdlUsage, PythonCodeBlock, ReadBlock, RepeatBlock, Role,
31+
StringOrBoolean, TextBlock,
3132
};
3233

3334
#[derive(Serialize, Deserialize, Debug, Clone)]
@@ -48,23 +49,11 @@ pub enum PdlResult {
4849
Dict(HashMap<String, PdlResult>),
4950
}
5051
impl ::std::fmt::Display for PdlResult {
51-
// This trait requires `fmt` with this exact signature.
5252
fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
5353
let s = to_string(&self).unwrap(); // TODO: .map_err(|e| e.to_string())?;
5454
write!(f, "{}", s)
5555
}
5656
}
57-
/*impl From<&Value> for PdlResult {
58-
fn from(v: &Value) -> Self {
59-
match v {
60-
Value::Bool(b) => PdlResult::Bool(*b),
61-
Value::String(s) => PdlResult::String(s.clone()),
62-
Value::Number(n) => PdlResult::Number(n.clone()),
63-
Value::Array(a) => PdlResult::List(a.into_iter().map(|v| v.into())),
64-
Value::Object(m) => PdlResult::Dict(m.into_iter().map(|(k,v)| (k, v.into()))),
65-
}
66-
}
67-
}*/
6857
impl From<&str> for PdlResult {
6958
fn from(s: &str) -> Self {
7059
PdlResult::String(s.to_string())
@@ -93,7 +82,6 @@ struct Interpreter<'a> {
9382
cwd: PathBuf,
9483
// id_stack: Vec<String>,
9584
jinja_env: Environment<'a>,
96-
// rt: Runtime,
9785
scope: Vec<Scope>,
9886
debug: bool,
9987
emit: bool,
@@ -116,7 +104,6 @@ impl<'a> Interpreter<'a> {
116104
cwd: current_dir().unwrap_or(PathBuf::from("/")),
117105
// id_stack: vec![],
118106
jinja_env: jinja_env,
119-
// rt: Runtime::new().unwrap(),
120107
scope: vec![Scope::new()],
121108
debug: false,
122109
emit: true,
@@ -160,14 +147,15 @@ impl<'a> Interpreter<'a> {
160147
PdlBlock::Read(block) => self.run_read(block, context).await,
161148
PdlBlock::Repeat(block) => self.run_repeat(block, context).await,
162149
PdlBlock::Text(block) => self.run_text(block, context).await,
150+
PdlBlock::Array(block) => self.run_array(block, context).await,
151+
PdlBlock::Message(block) => self.run_message(block, context).await,
163152
_ => Err(Box::from(format!("Unsupported block {:?}", program))),
164153
}?;
165154

166155
if match program {
167-
PdlBlock::Call(_) | PdlBlock::Text(_) => false,
156+
PdlBlock::Call(_) | PdlBlock::Model(_) | PdlBlock::Text(_) => false,
168157
_ => self.emit,
169158
} {
170-
// eprintln!("{:?}", program);
171159
println!("{}", pretty_print(&messages));
172160
}
173161
self.emit = prior_emit;
@@ -204,6 +192,17 @@ impl<'a> Interpreter<'a> {
204192
}))
205193
}
206194

195+
/// Evaluate String as a Jinja2 expression, expecting a string in response
196+
fn eval_to_string(&self, expr: &String) -> Result<String, PdlError> {
197+
match self.eval(expr)? {
198+
PdlResult::String(s) => Ok(s),
199+
x => Err(Box::from(format!(
200+
"Expression {expr} evaluated to non-string {:?}",
201+
x
202+
))),
203+
}
204+
}
205+
207206
fn eval_complex(&self, expr: &Value) -> Result<PdlResult, PdlError> {
208207
match expr {
209208
Value::Null => Ok("".into()),
@@ -499,7 +498,7 @@ impl<'a> Interpreter<'a> {
499498
println!("Model options {:?}", options);
500499
}
501500

502-
let messages = match &block.input {
501+
let input_messages = match &block.input {
503502
Some(input) => {
504503
// TODO ignoring result, trace
505504
let (_result, messages, _trace) = self.run_quiet(&*input, context).await?;
@@ -508,7 +507,7 @@ impl<'a> Interpreter<'a> {
508507
None => context,
509508
};
510509
let (prompt, history_slice): (&ChatMessage, &[ChatMessage]) =
511-
match messages.split_last() {
510+
match input_messages.split_last() {
512511
Some(x) => x,
513512
None => (&ChatMessage::user("".into()), &[]),
514513
};
@@ -523,6 +522,10 @@ impl<'a> Interpreter<'a> {
523522
);
524523
}
525524

525+
if self.emit {
526+
println!("{}", pretty_print(&input_messages));
527+
}
528+
526529
let req = ChatMessageRequest::new(model.into(), vec![prompt.clone()])
527530
.options(options)
528531
.tools(tools);
@@ -571,6 +574,7 @@ impl<'a> Interpreter<'a> {
571574
response_string += res.message.content.as_str();
572575
last_res = Some(res);
573576
}
577+
stdout.write_all(b"\n").await?;
574578

575579
let mut trace = block.clone();
576580
trace.pdl_result = Some(response_string.clone());
@@ -584,11 +588,10 @@ impl<'a> Interpreter<'a> {
584588
completion_nanos: usage.eval_duration,
585589
});
586590
}
587-
let mut message = res.message.clone();
588-
message.content = response_string;
591+
let output_messages = vec![ChatMessage::assistant(response_string)];
589592
Ok((
590-
message.content.clone().into(),
591-
vec![message],
593+
res.message.content.into(),
594+
output_messages,
592595
PdlBlock::Model(trace),
593596
))
594597
} else {
@@ -673,7 +676,7 @@ impl<'a> Interpreter<'a> {
673676
match role {
674677
Role::User => MessageRole::User,
675678
Role::Assistant => MessageRole::Assistant,
676-
Role::System => MessageRole::Assistant,
679+
Role::System => MessageRole::System,
677680
Role::Tool => MessageRole::Tool,
678681
}
679682
}
@@ -786,6 +789,70 @@ impl<'a> Interpreter<'a> {
786789
PdlBlock::Text(trace),
787790
))
788791
}
792+
793+
async fn run_array(&mut self, block: &ArrayBlock, context: Context) -> Interpretation {
794+
let mut result_items = vec![];
795+
let mut all_messages = vec![];
796+
let mut trace_items = vec![];
797+
798+
let mut iter = block.array.iter();
799+
while let Some(item) = iter.next() {
800+
// TODO accumulate messages
801+
let (result, messages, trace) = self.run_quiet(item, context.clone()).await?;
802+
result_items.push(result);
803+
all_messages.extend(messages);
804+
trace_items.push(trace);
805+
}
806+
807+
Ok((
808+
PdlResult::List(result_items),
809+
all_messages,
810+
PdlBlock::Array(ArrayBlock { array: trace_items }),
811+
))
812+
}
813+
814+
async fn run_message(&mut self, block: &MessageBlock, context: Context) -> Interpretation {
815+
let (content_result, content_messages, content_trace) =
816+
self.run(&block.content, context).await?;
817+
let name = if let Some(name) = &block.name {
818+
Some(self.eval_to_string(&name)?)
819+
} else {
820+
None
821+
};
822+
let tool_call_id = if let Some(tool_call_id) = &block.tool_call_id {
823+
Some(self.eval_to_string(&tool_call_id)?)
824+
} else {
825+
None
826+
};
827+
828+
let mut dict: HashMap<String, PdlResult> = HashMap::new();
829+
dict.insert("role".into(), PdlResult::String(to_string(&block.role)?));
830+
dict.insert("content".into(), content_result);
831+
if let Some(name) = &name {
832+
dict.insert("name".into(), PdlResult::String(name.clone()));
833+
}
834+
if let Some(tool_call_id) = &tool_call_id {
835+
dict.insert(
836+
"tool_call_id".into(),
837+
PdlResult::String(tool_call_id.clone()),
838+
);
839+
}
840+
841+
Ok((
842+
PdlResult::Dict(dict),
843+
content_messages
844+
.into_iter()
845+
.map(|m| ChatMessage::new(self.to_ollama_role(&block.role), m.content))
846+
.collect(),
847+
PdlBlock::Message(MessageBlock {
848+
role: block.role.clone(),
849+
content: Box::new(content_trace),
850+
description: block.description.clone(),
851+
name: name,
852+
tool_call_id: tool_call_id,
853+
}),
854+
))
855+
}
789856
}
790857

791858
pub async fn run(program: &PdlBlock, cwd: Option<PathBuf>, debug: bool) -> Interpretation {

pdl-live-react/src-tauri/src/pdl/interpreter_tests.rs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ mod tests {
2323
}
2424

2525
#[test]
26-
fn single_model_via_input() -> Result<(), Box<dyn Error>> {
26+
fn single_model_via_input_string() -> Result<(), Box<dyn Error>> {
2727
let (_, messages, _) = run(
2828
&PdlBlock::Model(ModelBlock::new(DEFAULT_MODEL).input_str("hello").build()),
2929
None,
@@ -54,6 +54,32 @@ mod tests {
5454
Ok(())
5555
}
5656

57+
#[test]
58+
fn single_model_via_input_array() -> Result<(), Box<dyn Error>> {
59+
let (_, messages, _) = run_json(
60+
json!({
61+
"model": DEFAULT_MODEL,
62+
"input": {
63+
"array": [
64+
{ "role": "system", "content": "answer as if you live in europe" },
65+
{ "role": "user", "content": "what is the fastest animal where you live?" },
66+
]
67+
}
68+
}),
69+
false,
70+
)?;
71+
assert_eq!(messages.len(), 1);
72+
assert_eq!(messages[0].role, MessageRole::Assistant);
73+
let m = messages[0].content.to_lowercase();
74+
assert!(
75+
m.contains("pronghorn")
76+
|| m.contains("falcon")
77+
|| m.contains("bison")
78+
|| m.contains("native")
79+
);
80+
Ok(())
81+
}
82+
5783
#[test]
5884
fn two_models_via_text_chain() -> Result<(), Box<dyn Error>> {
5985
let (_, messages, _) = run_json(
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
model: ollama/granite3.2:2b
2+
input:
3+
array:
4+
- role: system
5+
content: answer as if you live in europe
6+
- role: user
7+
content: what is the fastest animal where i live?
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
model: ollama/granite3.2:2b
2+
input: what is the fastest animal?

0 commit comments

Comments
 (0)