diff --git a/pdl-live-react/src-tauri/src/pdl/interpreter.rs b/pdl-live-react/src-tauri/src/pdl/interpreter.rs index 2a5548996..2c8c4b1f8 100644 --- a/pdl-live-react/src-tauri/src/pdl/interpreter.rs +++ b/pdl-live-react/src-tauri/src/pdl/interpreter.rs @@ -60,6 +60,7 @@ struct State { scope: Scope, escaped_variables: Vec, messages: Messages, + id_stack: Vec, } impl State { @@ -70,6 +71,7 @@ impl State { scope: initial_scope, escaped_variables: vec![], messages: vec![], + id_stack: vec![], } } @@ -85,6 +87,19 @@ impl State { s } + fn with_iter(&self, iter: usize) -> Self { + let mut s = self.clone(); + s.id_stack.push(format!("{iter}")); + s + } + + fn incr_iter(&self, iter: usize) -> Self { + let mut s = self.clone(); + s.id_stack.pop(); + s.id_stack.push(format!("{iter}")); + s + } + fn extend_scope(&self, scopes: Vec) -> Self { let mut s = self.clone(); scopes.into_iter().for_each(|m| s.scope.extend(m)); @@ -94,7 +109,6 @@ impl State { struct Interpreter<'a> { // batch: u32, - // id_stack: Vec, options: RunOptions<'a>, jinja_env: Environment<'a>, } @@ -188,20 +202,62 @@ impl<'a> Interpreter<'a> { self.process_defs(&m.defs, state).await?; let (result, messages, trace_body) = match &block.body { - Call(b) => self.run_call(b, m, state).await, - Data(b) => self.run_data(b, m, state).await, - If(b) => self.run_if(b, m, state).await, - Import(b) => self.run_import(b, m, state).await, - Include(b) => self.run_include(b, m, state).await, - Model(b) => self.run_model(b, m, state).await, - Object(b) => self.run_object(b, m, state).await, - PythonCode(b) => self.run_python_code(b, m, state).await, - Read(b) => self.run_read(b, m, state).await, - Repeat(b) => self.run_repeat(b, m, state).await, - LastOf(b) => self.run_sequence(b, m, state).await, - Text(b) => self.run_sequence(b, m, state).await, - Array(b) => self.run_array(b, m, state).await, - Message(b) => self.run_message(b, m, state).await, + Call(b) => { + state.id_stack.push("call".to_string()); + self.run_call(b, m, state).await + } + Data(b) => { + state.id_stack.push("data".to_string()); + self.run_data(b, m, state).await + } + If(b) => { + state.id_stack.push("if".to_string()); + self.run_if(b, m, state).await + } + Import(b) => { + state.id_stack.push("import".to_string()); + self.run_import(b, m, state).await + } + Include(b) => { + state.id_stack.push("include".to_string()); + self.run_include(b, m, state).await + } + Model(b) => { + state.id_stack.push("model".to_string()); + self.run_model(b, m, state).await + } + Object(b) => { + state.id_stack.push("object".to_string()); + self.run_object(b, m, state).await + } + PythonCode(b) => { + state.id_stack.push("code".to_string()); + self.run_python_code(b, m, state).await + } + Read(b) => { + state.id_stack.push("read".to_string()); + self.run_read(b, m, state).await + } + Repeat(b) => { + state.id_stack.push("repeat".to_string()); + self.run_repeat(b, m, state).await + } + LastOf(b) => { + state.id_stack.push("lastOf".to_string()); + self.run_sequence(b, m, state).await + } + Text(b) => { + state.id_stack.push("text".to_string()); + self.run_sequence(b, m, state).await + } + Array(b) => { + state.id_stack.push("array".to_string()); + self.run_array(b, m, state).await + } + Message(b) => { + state.id_stack.push("message".to_string()); + self.run_message(b, m, state).await + } }?; let mut trace = Block { @@ -212,10 +268,13 @@ impl<'a> Interpreter<'a> { timing.end()?; let mut trace_metadata = m.clone(); + trace_metadata.pdl_id = Some(state.id_stack.join(".")); trace_metadata.pdl_timing = Some(timing); trace_metadata.pdl_result = Some(Box::new(result.clone())); trace.metadata = Some(trace_metadata); + state.id_stack.pop(); + Ok((result, messages, Advanced(trace))) } @@ -854,8 +913,12 @@ impl<'a> Interpreter<'a> { completion_nanos: Some(usage.eval_duration), }); } - let output_messages = vec![ChatMessage::assistant(response_string)]; - Ok((res.message.content.into(), output_messages, Model(trace))) + let output_messages = vec![ChatMessage::assistant(response_string.clone())]; + Ok(( + PdlResult::String(response_string), + output_messages, + Model(trace), + )) } else { // nothing came out of the model Ok(("".into(), vec![], Model(trace))) @@ -949,10 +1012,6 @@ impl<'a> Interpreter<'a> { _metadata: &Metadata, state: &mut State, ) -> BodyInterpretation { - // { i:[1,2,3], j: [4,5,6]} -> ([i,j], [[1,2,3],[4,5,6]]) - // let (variables, values): (Vec<_>, Vec>) = block - // .into_iter() - // .unzip(); let iter_scopes = block .for_ .iter() @@ -971,14 +1030,16 @@ impl<'a> Interpreter<'a> { let mut results = vec![]; let mut messages = vec![]; let mut trace = vec![]; - let mut iter_state = state.clone(); + let mut iter_state = state.with_iter(0); if let Some(n) = iter_scopes.iter().map(|(_, v)| v.len()).min() { for iter in 0..n { let this_iter_scope = iter_scopes .iter() .map(|(k, v)| (k.clone(), v[iter].clone())) .collect(); - iter_state = iter_state.extend_scope(vec![this_iter_scope]); + iter_state = iter_state + .incr_iter(iter) + .extend_scope(vec![this_iter_scope]); let (result, ms, t) = self.run_quiet(&block.repeat, &mut iter_state).await?; results.push(result); messages.extend(ms); @@ -987,6 +1048,9 @@ impl<'a> Interpreter<'a> { } } + state.scope = iter_state.scope; + state.escaped_variables = iter_state.escaped_variables; + Ok((PdlResult::List(results), messages, Repeat(block.clone()))) } @@ -1047,28 +1111,35 @@ impl<'a> Interpreter<'a> { // here is where we iterate over the sequence items let mut iter = block.items().iter(); + let mut idx = 0; + let mut iter_state = state.with_iter(idx); while let Some(block) = iter.next() { + idx += 1; + // run each element of the Text block - let (this_result, this_messages, trace) = self.run(&block, state).await?; + let (this_result, this_messages, trace) = self.run(&block, &mut iter_state).await?; - state.messages.extend(this_messages.iter().cloned()); + iter_state = iter_state.incr_iter(idx); + iter_state.messages.extend(this_messages.iter().cloned()); output_results.push(this_result); output_messages.extend(this_messages.iter().cloned()); output_blocks.push(trace); } - // self.scope.pop(); - let trace = block.with_items(output_blocks); let result = self.def( &metadata.def, &trace.result_for(output_results), trace.parser(), - state, + &mut iter_state, true, )?; let result_messages = trace.messages_for::(&output_messages); + + state.scope = iter_state.scope; + state.escaped_variables = iter_state.escaped_variables; + Ok(( result, match block.role() { @@ -1165,13 +1236,21 @@ pub async fn run<'a>( initial_scope: Scope, ) -> Interpretation { crate::pdl::pull::pull_if_needed(&program).await?; - + let trace_file = options.trace.clone(); let mut interpreter = Interpreter::new(options); let mut state = State::new(initial_scope); if let Some(cwd) = cwd { state.cwd = cwd } - interpreter.run(&program, &mut state).await + + let res = interpreter.run(&program, &mut state).await?; + if let Some(trace_file) = trace_file { + let file = ::std::fs::File::create(trace_file)?; + let mut writer = ::std::io::BufWriter::new(file); + serde_json::to_writer(&mut writer, &res.2)?; + } + + Ok(res) } #[allow(dead_code)]