Skip to content

Commit e88b694

Browse files
committed
pass1 on PdlResult
Signed-off-by: Nick Mitchell <[email protected]>
1 parent 8a2bfa7 commit e88b694

File tree

1 file changed

+59
-20
lines changed

1 file changed

+59
-20
lines changed

pdl-live-react/src-tauri/src/pdl/interpreter.rs

Lines changed: 59 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,56 @@ use ollama_rs::{
2121
Ollama,
2222
};
2323

24-
use serde_json::{from_str, to_string, Map, Value};
24+
use serde::{Deserialize, Serialize};
25+
use serde_json::{from_str, to_string, Map, Number, Value};
2526
use serde_norway::{from_reader, from_str as from_yaml_str};
2627

2728
use crate::pdl::ast::{
28-
CallBlock, IfBlock, ListOrString, ModelBlock, ObjectBlock, PdlBlock, PdlParser, PdlUsage,
29+
CallBlock, FunctionBlock, IfBlock, ListOrString, ModelBlock, ObjectBlock, PdlBlock, PdlParser, PdlUsage,
2930
PythonCodeBlock, ReadBlock, RepeatBlock, Role, StringOrBoolean, TextBlock,
3031
};
3132

32-
type PdlResult = Value;
33+
#[derive(Serialize, Deserialize, Debug, Clone)]
34+
struct Closure {
35+
scope: Box<Scope>,
36+
function: FunctionBlock
37+
}
38+
39+
#[derive(Serialize, Deserialize, Debug, Clone)]
40+
#[serde(untagged)]
41+
enum PdlResult {
42+
Closure(Closure),
43+
List(Vec<Box<PdlResult>>),
44+
Object(HashMap<String, Box<PdlResult>>),
45+
Value(Value),
46+
}
47+
impl ::std::fmt::Display for PdlResult {
48+
// This trait requires `fmt` with this exact signature.
49+
fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
50+
let s = match self {
51+
PdlResult::Value(Value::String(s)) => s.clone(),
52+
PdlResult::Closure(c) => to_string(&c.function).unwrap(),
53+
x => x.to_string()
54+
};
55+
write!(f, "{}", s)
56+
}
57+
}
58+
impl From<Value> for PdlResult {
59+
fn from(v: Value) -> Self {
60+
PdlResult::Value(v)
61+
}
62+
}
63+
impl From<String> for PdlResult {
64+
fn from(s: String) -> Self {
65+
PdlResult::Value(Value::String(s))
66+
}
67+
}
68+
impl From<Number> for PdlResult {
69+
fn from(n: Number) -> Self {
70+
PdlResult::Value(Value::Number(n))
71+
}
72+
}
73+
3374
type Context = Vec<ChatMessage>;
3475
type Scope = HashMap<String, Value>;
3576
type PdlError = Box<dyn Error + Send + Sync>;
@@ -83,11 +124,11 @@ impl<'a> Interpreter<'a> {
83124

84125
let (result, messages, trace) = match program {
85126
PdlBlock::Number(n) => Ok((
86-
Value::Number(n.clone()),
127+
n.clone().into(),
87128
vec![ChatMessage::user(format!("{n}"))],
88129
PdlBlock::Number(n.clone()),
89130
)),
90-
PdlBlock::Function(f) => Ok((Value::Null, vec![], PdlBlock::Function(f.clone()))),
131+
PdlBlock::Function(f) => Ok((Value::Null.into(), vec![], PdlBlock::Function(f.clone()))),
91132
PdlBlock::String(s) => self.run_string(s, context).await,
92133
PdlBlock::Call(block) => self.run_call(block, context).await,
93134
PdlBlock::If(block) => self.run_if(block, context).await,
@@ -191,7 +232,7 @@ impl<'a> Interpreter<'a> {
191232
};
192233
let messages = vec![ChatMessage::user(result.clone())];
193234

194-
Ok((Value::String(result), messages, trace))
235+
Ok((result.into(), messages, trace))
195236
}
196237

197238
fn path_to(&self, file_path: &String) -> PathBuf {
@@ -228,7 +269,7 @@ impl<'a> Interpreter<'a> {
228269
}
229270
}
230271

231-
Ok(result)
272+
Ok(result.into())
232273
}
233274

234275
/// Run a PdlBlock::Read
@@ -305,7 +346,7 @@ impl<'a> Interpreter<'a> {
305346
Value::Bool(true) => self.run_quiet(&block.then, context).await,
306347
Value::Bool(false) => match &block.else_ {
307348
Some(else_block) => self.run_quiet(&else_block, context).await,
308-
None => Ok((Value::Null, vec![], PdlBlock::If(block.clone()))),
349+
None => Ok((Value::Null.into(), vec![], PdlBlock::If(block.clone()))),
309350
},
310351
x => Err(Box::from(format!(
311352
"if block condition evaluated to non-boolean value: {:?}",
@@ -392,7 +433,7 @@ impl<'a> Interpreter<'a> {
392433
.unwrap();
393434
let messages = vec![ChatMessage::user(result_string.as_str().to_string())];
394435
let trace = PdlBlock::PythonCode(block.clone());
395-
Ok((Value::String(messages[0].content.clone()), messages, trace))
436+
Ok((messages[0].content.clone().into(), messages, trace))
396437
}
397438
Err(_) => Err(Box::from(
398439
"Python code block failed to assign a 'result' variable",
@@ -507,13 +548,13 @@ impl<'a> Interpreter<'a> {
507548
let mut message = res.message.clone();
508549
message.content = response_string;
509550
Ok((
510-
Value::String(message.content.clone()),
551+
message.content.clone().into(),
511552
vec![message],
512553
PdlBlock::Model(trace),
513554
))
514555
} else {
515556
// nothing came out of the model
516-
Ok((Value::Null, vec![], PdlBlock::Model(trace)))
557+
Ok((Value::Null.into(), vec![], PdlBlock::Model(trace)))
517558
}
518559
// dbg!(history);
519560
}
@@ -523,22 +564,23 @@ impl<'a> Interpreter<'a> {
523564

524565
async fn run_object(&mut self, block: &ObjectBlock, context: Context) -> Interpretation {
525566
let mut messages = vec![];
526-
let mut result_map = Map::new();
567+
let mut result_map = HashMap::new();
527568
let mut trace_map = HashMap::new();
528569
let mut iter = block.object.iter();
529570
while let Some((k, v)) = iter.next() {
530571
let (this_result, this_messages, this_trace) =
531572
self.run_quiet(v, context.clone()).await?;
532573
messages.extend(this_messages);
533-
result_map.insert(k.clone(), this_result);
574+
result_map.insert(k.clone(), Box::from(this_result))
575+
;
534576
trace_map.insert(k.clone(), this_trace);
535577
}
536578

537579
if self.debug {
538580
eprintln!("Object {:?}", result_map);
539581
}
540582
Ok((
541-
Value::Object(result_map),
583+
PdlResult::Object(result_map),
542584
messages,
543585
PdlBlock::Object(ObjectBlock { object: trace_map }),
544586
))
@@ -574,15 +616,15 @@ impl<'a> Interpreter<'a> {
574616
.collect();
575617
self.extend_scope_with_map(scope);
576618
let (result, ms, t) = self.run_quiet(&block.repeat, context.clone()).await?;
577-
results.push(result);
619+
results.push(Box::from(result));
578620
messages.extend(ms);
579621
trace.push(t);
580622
self.scope.pop();
581623
}
582624
}
583625

584626
Ok((
585-
Value::Array(results),
627+
PdlResult::List(results),
586628
messages,
587629
PdlBlock::Repeat(block.clone()),
588630
))
@@ -679,10 +721,7 @@ impl<'a> Interpreter<'a> {
679721

680722
let result_string = output_results
681723
.into_iter()
682-
.map(|m| match m {
683-
Value::String(s) => s,
684-
x => x.to_string(),
685-
})
724+
.map(|m| m.to_string())
686725
.collect::<Vec<_>>()
687726
.join("\n");
688727

0 commit comments

Comments
 (0)