diff --git a/pdl-live-react/src-tauri/src/pdl/interpreter.rs b/pdl-live-react/src-tauri/src/pdl/interpreter.rs
index 2e0a2033c..8bab94bb8 100644
--- a/pdl-live-react/src-tauri/src/pdl/interpreter.rs
+++ b/pdl-live-react/src-tauri/src/pdl/interpreter.rs
@@ -813,7 +813,8 @@ impl<'a> Interpreter<'a> {
         block: &ModelBlock,
         metadata: &Metadata,
         state: &mut State,
-    ) -> BodyInterpretation {
+        input_messages: Vec<ChatMessage>,
+    ) -> Result<(String, Option<PdlUsage>), PdlError> {
         let mut ollama = Ollama::default();
         let model = if pdl_model.starts_with("ollama/") {
             &pdl_model[7..]
@@ -827,19 +828,6 @@ impl<'a> Interpreter<'a> {
             eprintln!("Model tools {:?} {:?}", metadata.description, tools);
         }
 
-        let mut trace = block.clone();
-
-        // The input messages to the model is either:
-        // a) block.input, if given
-        // b) the current state's accumulated messages
-        let input_messages = match &block.input {
-            Some(input) => {
-                // TODO ignoring result and trace
-                let (_result, messages, _trace) = self.run_quiet(&*input, state).await?;
-                messages
-            }
-            None => state.messages.clone(),
-        };
         let (prompt, history_slice): (&ChatMessage, &[ChatMessage]) =
             match input_messages.split_last() {
                 Some(x) => x,
@@ -853,10 +841,6 @@ impl<'a> Interpreter<'a> {
             );
         }
 
-        //if state.emit {
-        //println!("{}", pretty_print(&input_messages));
-        //}
-
         let req = ChatMessageRequest::new(model.into(), vec![prompt.clone()])
             .options(options)
             .tools(tools);
@@ -919,6 +903,45 @@ impl<'a> Interpreter<'a> {
             }
         }
 
+        let usage = if let Some(res) = last_res {
+            if let Some(usage) = res.final_data {
+                Some(PdlUsage {
+                    prompt_tokens: usage.prompt_eval_count,
+                    prompt_nanos: Some(usage.prompt_eval_duration),
+                    completion_tokens: usage.eval_count,
+                    completion_nanos: Some(usage.eval_duration),
+                })
+            } else {
+                None
+            }
+        } else {
+            None
+        };
+
+        Ok((response_string, usage))
+    }
+
+    /// Run a PdlBlock::Model
+    async fn run_model(
+        &mut self,
+        block: &ModelBlock,
+        metadata: &Metadata,
+        state: &mut State,
+    ) -> BodyInterpretation {
+        // The input messages to the model are either:
+        // a) block.input, if given
+        // b) the current state's accumulated messages
+        let input_messages = match &block.input {
+            Some(input) => {
+                // TODO ignoring result and trace
+                let (_result, messages, _trace) = self.run_quiet(&*input, state).await?;
+                messages
+            }
+            None => state.messages.clone(),
+        };
+
+        let mut trace = block.clone();
+
         // TODO, does this belong in run_advanced(), and does
         // trace.context belong in Metadata rather than ModelBlock
         trace.context = Some(
@@ -948,42 +971,30 @@ impl<'a> Interpreter<'a> {
                 .collect(),
         );
 
-        if let Some(res) = last_res {
-            if let Some(usage) = res.final_data {
-                trace.pdl_usage = Some(PdlUsage {
-                    prompt_tokens: usage.prompt_eval_count,
-                    prompt_nanos: Some(usage.prompt_eval_duration),
-                    completion_tokens: usage.eval_count,
-                    completion_nanos: Some(usage.eval_duration),
-                });
-            }
-            let output_messages = vec![ChatMessage::assistant(response_string.clone())];
-            Ok((
-                PdlResult::String(response_string),
-                output_messages,
-                Model(trace),
-            ))
-        } else {
-            // nothing came out of the model
-            Ok(("".into(), vec![], Model(trace)))
-        }
-        // dbg!(history);
-    }
+        let (response_string, usage) =
+            if let PdlResult::String(s) = self.eval_string_to_string(&block.model, state)? {
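+                // Dispatch on the model-name prefix. Only ollama/ and
+                // ollama_chat/ are routed for now; the openai/ arm below
+                // is left commented out as a placeholder.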
+                if s.starts_with("ollama/") || s.starts_with("ollama_chat/") {
+                    self.run_ollama_model(s, block, metadata, state, input_messages)
+                        .await
+                /*} else if s.starts_with("openai/") {
+                    return self.run_openai_model(s, block, metadata, state, input_messages).await;*/
+                } else {
+                    Err(Box::from(format!("Unsupported model {:?}", block.model)))
+                }
+            } else {
+                Err(Box::from(format!(
+                    "Model expression evaluated to non-string {:?}",
+                    block.model
+                )))
+            }?;
-
-    /// Run a PdlBlock::Model
-    async fn run_model(
-        &mut self,
-        block: &ModelBlock,
-        metadata: &Metadata,
-        state: &mut State,
-    ) -> BodyInterpretation {
-        if let PdlResult::String(s) = self.eval_string_to_string(&block.model, state)? {
-            if s.starts_with("ollama/") || s.starts_with("ollama_chat/") {
-                return self.run_ollama_model(s, block, metadata, state).await;
-            }
-        }
-
+
+        trace.pdl_usage = usage;
+
-        Err(Box::from(format!("Unsupported model {:?}", block.model)))
+        Ok((
+            PdlResult::String(response_string.clone()),
+            vec![ChatMessage::assistant(response_string)],
+            Model(trace),
+        ))
     }
 
     /// Run a PdlBlock::Data