@@ -60,6 +60,7 @@ struct State {
6060 scope : Scope ,
6161 escaped_variables : Vec < String > ,
6262 messages : Messages ,
63+ id_stack : Vec < String > ,
6364}
6465
6566impl State {
@@ -70,6 +71,7 @@ impl State {
7071 scope : initial_scope,
7172 escaped_variables : vec ! [ ] ,
7273 messages : vec ! [ ] ,
74+ id_stack : vec ! [ ] ,
7375 }
7476 }
7577
@@ -85,6 +87,19 @@ impl State {
8587 s
8688 }
8789
90+ fn with_iter ( & self , iter : usize ) -> Self {
91+ let mut s = self . clone ( ) ;
92+ s. id_stack . push ( format ! ( "{iter}" ) ) ;
93+ s
94+ }
95+
96+ fn incr_iter ( & self , iter : usize ) -> Self {
97+ let mut s = self . clone ( ) ;
98+ s. id_stack . pop ( ) ;
99+ s. id_stack . push ( format ! ( "{iter}" ) ) ;
100+ s
101+ }
102+
88103 fn extend_scope ( & self , scopes : Vec < Scope > ) -> Self {
89104 let mut s = self . clone ( ) ;
90105 scopes. into_iter ( ) . for_each ( |m| s. scope . extend ( m) ) ;
@@ -94,7 +109,6 @@ impl State {
94109
95110struct Interpreter < ' a > {
96111 // batch: u32,
97- // id_stack: Vec<String>,
98112 options : RunOptions < ' a > ,
99113 jinja_env : Environment < ' a > ,
100114}
@@ -188,20 +202,62 @@ impl<'a> Interpreter<'a> {
188202 self . process_defs ( & m. defs , state) . await ?;
189203
190204 let ( result, messages, trace_body) = match & block. body {
191- Call ( b) => self . run_call ( b, m, state) . await ,
192- Data ( b) => self . run_data ( b, m, state) . await ,
193- If ( b) => self . run_if ( b, m, state) . await ,
194- Import ( b) => self . run_import ( b, m, state) . await ,
195- Include ( b) => self . run_include ( b, m, state) . await ,
196- Model ( b) => self . run_model ( b, m, state) . await ,
197- Object ( b) => self . run_object ( b, m, state) . await ,
198- PythonCode ( b) => self . run_python_code ( b, m, state) . await ,
199- Read ( b) => self . run_read ( b, m, state) . await ,
200- Repeat ( b) => self . run_repeat ( b, m, state) . await ,
201- LastOf ( b) => self . run_sequence ( b, m, state) . await ,
202- Text ( b) => self . run_sequence ( b, m, state) . await ,
203- Array ( b) => self . run_array ( b, m, state) . await ,
204- Message ( b) => self . run_message ( b, m, state) . await ,
205+ Call ( b) => {
206+ state. id_stack . push ( "call" . to_string ( ) ) ;
207+ self . run_call ( b, m, state) . await
208+ }
209+ Data ( b) => {
210+ state. id_stack . push ( "data" . to_string ( ) ) ;
211+ self . run_data ( b, m, state) . await
212+ }
213+ If ( b) => {
214+ state. id_stack . push ( "if" . to_string ( ) ) ;
215+ self . run_if ( b, m, state) . await
216+ }
217+ Import ( b) => {
218+ state. id_stack . push ( "import" . to_string ( ) ) ;
219+ self . run_import ( b, m, state) . await
220+ }
221+ Include ( b) => {
222+ state. id_stack . push ( "include" . to_string ( ) ) ;
223+ self . run_include ( b, m, state) . await
224+ }
225+ Model ( b) => {
226+ state. id_stack . push ( "model" . to_string ( ) ) ;
227+ self . run_model ( b, m, state) . await
228+ }
229+ Object ( b) => {
230+ state. id_stack . push ( "object" . to_string ( ) ) ;
231+ self . run_object ( b, m, state) . await
232+ }
233+ PythonCode ( b) => {
234+ state. id_stack . push ( "code" . to_string ( ) ) ;
235+ self . run_python_code ( b, m, state) . await
236+ }
237+ Read ( b) => {
238+ state. id_stack . push ( "read" . to_string ( ) ) ;
239+ self . run_read ( b, m, state) . await
240+ }
241+ Repeat ( b) => {
242+ state. id_stack . push ( "repeat" . to_string ( ) ) ;
243+ self . run_repeat ( b, m, state) . await
244+ }
245+ LastOf ( b) => {
246+ state. id_stack . push ( "lastOf" . to_string ( ) ) ;
247+ self . run_sequence ( b, m, state) . await
248+ }
249+ Text ( b) => {
250+ state. id_stack . push ( "text" . to_string ( ) ) ;
251+ self . run_sequence ( b, m, state) . await
252+ }
253+ Array ( b) => {
254+ state. id_stack . push ( "array" . to_string ( ) ) ;
255+ self . run_array ( b, m, state) . await
256+ }
257+ Message ( b) => {
258+ state. id_stack . push ( "message" . to_string ( ) ) ;
259+ self . run_message ( b, m, state) . await
260+ }
205261 } ?;
206262
207263 let mut trace = Block {
@@ -212,10 +268,13 @@ impl<'a> Interpreter<'a> {
212268 timing. end ( ) ?;
213269
214270 let mut trace_metadata = m. clone ( ) ;
271+ trace_metadata. pdl_id = Some ( state. id_stack . join ( "." ) ) ;
215272 trace_metadata. pdl_timing = Some ( timing) ;
216273 trace_metadata. pdl_result = Some ( Box :: new ( result. clone ( ) ) ) ;
217274 trace. metadata = Some ( trace_metadata) ;
218275
276+ state. id_stack . pop ( ) ;
277+
219278 Ok ( ( result, messages, Advanced ( trace) ) )
220279 }
221280
@@ -854,8 +913,12 @@ impl<'a> Interpreter<'a> {
854913 completion_nanos : Some ( usage. eval_duration ) ,
855914 } ) ;
856915 }
857- let output_messages = vec ! [ ChatMessage :: assistant( response_string) ] ;
858- Ok ( ( res. message . content . into ( ) , output_messages, Model ( trace) ) )
916+ let output_messages = vec ! [ ChatMessage :: assistant( response_string. clone( ) ) ] ;
917+ Ok ( (
918+ PdlResult :: String ( response_string) ,
919+ output_messages,
920+ Model ( trace) ,
921+ ) )
859922 } else {
860923 // nothing came out of the model
861924 Ok ( ( "" . into ( ) , vec ! [ ] , Model ( trace) ) )
@@ -949,10 +1012,6 @@ impl<'a> Interpreter<'a> {
9491012 _metadata : & Metadata ,
9501013 state : & mut State ,
9511014 ) -> BodyInterpretation {
952- // { i:[1,2,3], j: [4,5,6]} -> ([i,j], [[1,2,3],[4,5,6]])
953- // let (variables, values): (Vec<_>, Vec<Vec<_>>) = block
954- // .into_iter()
955- // .unzip();
9561015 let iter_scopes = block
9571016 . for_
9581017 . iter ( )
@@ -971,14 +1030,16 @@ impl<'a> Interpreter<'a> {
9711030 let mut results = vec ! [ ] ;
9721031 let mut messages = vec ! [ ] ;
9731032 let mut trace = vec ! [ ] ;
974- let mut iter_state = state. clone ( ) ;
1033+ let mut iter_state = state. with_iter ( 0 ) ;
9751034 if let Some ( n) = iter_scopes. iter ( ) . map ( |( _, v) | v. len ( ) ) . min ( ) {
9761035 for iter in 0 ..n {
9771036 let this_iter_scope = iter_scopes
9781037 . iter ( )
9791038 . map ( |( k, v) | ( k. clone ( ) , v[ iter] . clone ( ) ) )
9801039 . collect ( ) ;
981- iter_state = iter_state. extend_scope ( vec ! [ this_iter_scope] ) ;
1040+ iter_state = iter_state
1041+ . incr_iter ( iter)
1042+ . extend_scope ( vec ! [ this_iter_scope] ) ;
9821043 let ( result, ms, t) = self . run_quiet ( & block. repeat , & mut iter_state) . await ?;
9831044 results. push ( result) ;
9841045 messages. extend ( ms) ;
@@ -987,6 +1048,9 @@ impl<'a> Interpreter<'a> {
9871048 }
9881049 }
9891050
1051+ state. scope = iter_state. scope ;
1052+ state. escaped_variables = iter_state. escaped_variables ;
1053+
9901054 Ok ( ( PdlResult :: List ( results) , messages, Repeat ( block. clone ( ) ) ) )
9911055 }
9921056
@@ -1047,28 +1111,35 @@ impl<'a> Interpreter<'a> {
10471111
10481112 // here is where we iterate over the sequence items
10491113 let mut iter = block. items ( ) . iter ( ) ;
1114+ let mut idx = 0 ;
1115+ let mut iter_state = state. with_iter ( idx) ;
10501116 while let Some ( block) = iter. next ( ) {
1117+ idx += 1 ;
1118+
10511119 // run each element of the Text block
1052- let ( this_result, this_messages, trace) = self . run ( & block, state ) . await ?;
1120+ let ( this_result, this_messages, trace) = self . run ( & block, & mut iter_state ) . await ?;
10531121
1054- state. messages . extend ( this_messages. iter ( ) . cloned ( ) ) ;
1122+ iter_state = iter_state. incr_iter ( idx) ;
1123+ iter_state. messages . extend ( this_messages. iter ( ) . cloned ( ) ) ;
10551124
10561125 output_results. push ( this_result) ;
10571126 output_messages. extend ( this_messages. iter ( ) . cloned ( ) ) ;
10581127 output_blocks. push ( trace) ;
10591128 }
10601129
1061- // self.scope.pop();
1062-
10631130 let trace = block. with_items ( output_blocks) ;
10641131 let result = self . def (
10651132 & metadata. def ,
10661133 & trace. result_for ( output_results) ,
10671134 trace. parser ( ) ,
1068- state ,
1135+ & mut iter_state ,
10691136 true ,
10701137 ) ?;
10711138 let result_messages = trace. messages_for :: < ChatMessage > ( & output_messages) ;
1139+
1140+ state. scope = iter_state. scope ;
1141+ state. escaped_variables = iter_state. escaped_variables ;
1142+
10721143 Ok ( (
10731144 result,
10741145 match block. role ( ) {
@@ -1165,13 +1236,21 @@ pub async fn run<'a>(
11651236 initial_scope : Scope ,
11661237) -> Interpretation {
11671238 crate :: pdl:: pull:: pull_if_needed ( & program) . await ?;
1168-
1239+ let trace_file = options . trace . clone ( ) ;
11691240 let mut interpreter = Interpreter :: new ( options) ;
11701241 let mut state = State :: new ( initial_scope) ;
11711242 if let Some ( cwd) = cwd {
11721243 state. cwd = cwd
11731244 }
1174- interpreter. run ( & program, & mut state) . await
1245+
1246+ let res = interpreter. run ( & program, & mut state) . await ?;
1247+ if let Some ( trace_file) = trace_file {
1248+ let file = :: std:: fs:: File :: create ( trace_file) ?;
1249+ let mut writer = :: std:: io:: BufWriter :: new ( file) ;
1250+ serde_json:: to_writer ( & mut writer, & res. 2 ) ?;
1251+ }
1252+
1253+ Ok ( res)
11751254}
11761255
11771256#[ allow( dead_code) ]
0 commit comments