Skip to content

Commit 10d0f54

Browse files
committed
feat: initial pdl__id and --trace support for rust interpreter
Signed-off-by: Nick Mitchell <[email protected]>
1 parent 316a31b commit 10d0f54

File tree

1 file changed

+109
-30
lines changed

1 file changed

+109
-30
lines changed

pdl-live-react/src-tauri/src/pdl/interpreter.rs

Lines changed: 109 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ struct State {
6060
scope: Scope,
6161
escaped_variables: Vec<String>,
6262
messages: Messages,
63+
id_stack: Vec<String>,
6364
}
6465

6566
impl State {
@@ -70,6 +71,7 @@ impl State {
7071
scope: initial_scope,
7172
escaped_variables: vec![],
7273
messages: vec![],
74+
id_stack: vec![],
7375
}
7476
}
7577

@@ -85,6 +87,19 @@ impl State {
8587
s
8688
}
8789

90+
fn with_iter(&self, iter: usize) -> Self {
91+
let mut s = self.clone();
92+
s.id_stack.push(format!("{iter}"));
93+
s
94+
}
95+
96+
fn incr_iter(&self, iter: usize) -> Self {
97+
let mut s = self.clone();
98+
s.id_stack.pop();
99+
s.id_stack.push(format!("{iter}"));
100+
s
101+
}
102+
88103
fn extend_scope(&self, scopes: Vec<Scope>) -> Self {
89104
let mut s = self.clone();
90105
scopes.into_iter().for_each(|m| s.scope.extend(m));
@@ -94,7 +109,6 @@ impl State {
94109

95110
struct Interpreter<'a> {
96111
// batch: u32,
97-
// id_stack: Vec<String>,
98112
options: RunOptions<'a>,
99113
jinja_env: Environment<'a>,
100114
}
@@ -188,20 +202,62 @@ impl<'a> Interpreter<'a> {
188202
self.process_defs(&m.defs, state).await?;
189203

190204
let (result, messages, trace_body) = match &block.body {
191-
Call(b) => self.run_call(b, m, state).await,
192-
Data(b) => self.run_data(b, m, state).await,
193-
If(b) => self.run_if(b, m, state).await,
194-
Import(b) => self.run_import(b, m, state).await,
195-
Include(b) => self.run_include(b, m, state).await,
196-
Model(b) => self.run_model(b, m, state).await,
197-
Object(b) => self.run_object(b, m, state).await,
198-
PythonCode(b) => self.run_python_code(b, m, state).await,
199-
Read(b) => self.run_read(b, m, state).await,
200-
Repeat(b) => self.run_repeat(b, m, state).await,
201-
LastOf(b) => self.run_sequence(b, m, state).await,
202-
Text(b) => self.run_sequence(b, m, state).await,
203-
Array(b) => self.run_array(b, m, state).await,
204-
Message(b) => self.run_message(b, m, state).await,
205+
Call(b) => {
206+
state.id_stack.push("call".to_string());
207+
self.run_call(b, m, state).await
208+
}
209+
Data(b) => {
210+
state.id_stack.push("data".to_string());
211+
self.run_data(b, m, state).await
212+
}
213+
If(b) => {
214+
state.id_stack.push("if".to_string());
215+
self.run_if(b, m, state).await
216+
}
217+
Import(b) => {
218+
state.id_stack.push("import".to_string());
219+
self.run_import(b, m, state).await
220+
}
221+
Include(b) => {
222+
state.id_stack.push("include".to_string());
223+
self.run_include(b, m, state).await
224+
}
225+
Model(b) => {
226+
state.id_stack.push("model".to_string());
227+
self.run_model(b, m, state).await
228+
}
229+
Object(b) => {
230+
state.id_stack.push("object".to_string());
231+
self.run_object(b, m, state).await
232+
}
233+
PythonCode(b) => {
234+
state.id_stack.push("code".to_string());
235+
self.run_python_code(b, m, state).await
236+
}
237+
Read(b) => {
238+
state.id_stack.push("read".to_string());
239+
self.run_read(b, m, state).await
240+
}
241+
Repeat(b) => {
242+
state.id_stack.push("repeat".to_string());
243+
self.run_repeat(b, m, state).await
244+
}
245+
LastOf(b) => {
246+
state.id_stack.push("lastOf".to_string());
247+
self.run_sequence(b, m, state).await
248+
}
249+
Text(b) => {
250+
state.id_stack.push("text".to_string());
251+
self.run_sequence(b, m, state).await
252+
}
253+
Array(b) => {
254+
state.id_stack.push("array".to_string());
255+
self.run_array(b, m, state).await
256+
}
257+
Message(b) => {
258+
state.id_stack.push("message".to_string());
259+
self.run_message(b, m, state).await
260+
}
205261
}?;
206262

207263
let mut trace = Block {
@@ -212,10 +268,13 @@ impl<'a> Interpreter<'a> {
212268
timing.end()?;
213269

214270
let mut trace_metadata = m.clone();
271+
trace_metadata.pdl_id = Some(state.id_stack.join("."));
215272
trace_metadata.pdl_timing = Some(timing);
216273
trace_metadata.pdl_result = Some(Box::new(result.clone()));
217274
trace.metadata = Some(trace_metadata);
218275

276+
state.id_stack.pop();
277+
219278
Ok((result, messages, Advanced(trace)))
220279
}
221280

@@ -854,8 +913,12 @@ impl<'a> Interpreter<'a> {
854913
completion_nanos: Some(usage.eval_duration),
855914
});
856915
}
857-
let output_messages = vec![ChatMessage::assistant(response_string)];
858-
Ok((res.message.content.into(), output_messages, Model(trace)))
916+
let output_messages = vec![ChatMessage::assistant(response_string.clone())];
917+
Ok((
918+
PdlResult::String(response_string),
919+
output_messages,
920+
Model(trace),
921+
))
859922
} else {
860923
// nothing came out of the model
861924
Ok(("".into(), vec![], Model(trace)))
@@ -949,10 +1012,6 @@ impl<'a> Interpreter<'a> {
9491012
_metadata: &Metadata,
9501013
state: &mut State,
9511014
) -> BodyInterpretation {
952-
// { i:[1,2,3], j: [4,5,6]} -> ([i,j], [[1,2,3],[4,5,6]])
953-
// let (variables, values): (Vec<_>, Vec<Vec<_>>) = block
954-
// .into_iter()
955-
// .unzip();
9561015
let iter_scopes = block
9571016
.for_
9581017
.iter()
@@ -971,14 +1030,16 @@ impl<'a> Interpreter<'a> {
9711030
let mut results = vec![];
9721031
let mut messages = vec![];
9731032
let mut trace = vec![];
974-
let mut iter_state = state.clone();
1033+
let mut iter_state = state.with_iter(0);
9751034
if let Some(n) = iter_scopes.iter().map(|(_, v)| v.len()).min() {
9761035
for iter in 0..n {
9771036
let this_iter_scope = iter_scopes
9781037
.iter()
9791038
.map(|(k, v)| (k.clone(), v[iter].clone()))
9801039
.collect();
981-
iter_state = iter_state.extend_scope(vec![this_iter_scope]);
1040+
iter_state = iter_state
1041+
.incr_iter(iter)
1042+
.extend_scope(vec![this_iter_scope]);
9821043
let (result, ms, t) = self.run_quiet(&block.repeat, &mut iter_state).await?;
9831044
results.push(result);
9841045
messages.extend(ms);
@@ -987,6 +1048,9 @@ impl<'a> Interpreter<'a> {
9871048
}
9881049
}
9891050

1051+
state.scope = iter_state.scope;
1052+
state.escaped_variables = iter_state.escaped_variables;
1053+
9901054
Ok((PdlResult::List(results), messages, Repeat(block.clone())))
9911055
}
9921056

@@ -1047,28 +1111,35 @@ impl<'a> Interpreter<'a> {
10471111

10481112
// here is where we iterate over the sequence items
10491113
let mut iter = block.items().iter();
1114+
let mut idx = 0;
1115+
let mut iter_state = state.with_iter(idx);
10501116
while let Some(block) = iter.next() {
1117+
idx += 1;
1118+
10511119
// run each element of the Text block
1052-
let (this_result, this_messages, trace) = self.run(&block, state).await?;
1120+
let (this_result, this_messages, trace) = self.run(&block, &mut iter_state).await?;
10531121

1054-
state.messages.extend(this_messages.iter().cloned());
1122+
iter_state = iter_state.incr_iter(idx);
1123+
iter_state.messages.extend(this_messages.iter().cloned());
10551124

10561125
output_results.push(this_result);
10571126
output_messages.extend(this_messages.iter().cloned());
10581127
output_blocks.push(trace);
10591128
}
10601129

1061-
// self.scope.pop();
1062-
10631130
let trace = block.with_items(output_blocks);
10641131
let result = self.def(
10651132
&metadata.def,
10661133
&trace.result_for(output_results),
10671134
trace.parser(),
1068-
state,
1135+
&mut iter_state,
10691136
true,
10701137
)?;
10711138
let result_messages = trace.messages_for::<ChatMessage>(&output_messages);
1139+
1140+
state.scope = iter_state.scope;
1141+
state.escaped_variables = iter_state.escaped_variables;
1142+
10721143
Ok((
10731144
result,
10741145
match block.role() {
@@ -1165,13 +1236,21 @@ pub async fn run<'a>(
11651236
initial_scope: Scope,
11661237
) -> Interpretation {
11671238
crate::pdl::pull::pull_if_needed(&program).await?;
1168-
1239+
let trace_file = options.trace.clone();
11691240
let mut interpreter = Interpreter::new(options);
11701241
let mut state = State::new(initial_scope);
11711242
if let Some(cwd) = cwd {
11721243
state.cwd = cwd
11731244
}
1174-
interpreter.run(&program, &mut state).await
1245+
1246+
let res = interpreter.run(&program, &mut state).await?;
1247+
if let Some(trace_file) = trace_file {
1248+
let file = ::std::fs::File::create(trace_file)?;
1249+
let mut writer = ::std::io::BufWriter::new(file);
1250+
serde_json::to_writer(&mut writer, &res.2)?;
1251+
}
1252+
1253+
Ok(res)
11751254
}
11761255

11771256
#[allow(dead_code)]

0 commit comments

Comments
 (0)