diff --git a/pdl-live-react/src-tauri/src/pdl/ast.rs b/pdl-live-react/src-tauri/src/pdl/ast.rs
index 0c4df38d0..4a0fa4d94 100644
--- a/pdl-live-react/src-tauri/src/pdl/ast.rs
+++ b/pdl-live-react/src-tauri/src/pdl/ast.rs
@@ -274,7 +274,7 @@ pub struct PdlUsage {
 #[serde(tag = "kind", rename = "model")]
 #[builder(setter(into, strip_option), default)]
 pub struct ModelBlock {
-    pub model: String,
+    pub model: EvalsTo<String, String>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub parameters: Option<HashMap<String, Value>>,
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -468,6 +468,23 @@ pub enum EvalsTo<S, T> {
     Expr(Expr<S, T>),
 }
 
+impl Default for EvalsTo<String, String> {
+    fn default() -> Self {
+        EvalsTo::Const("".to_string())
+    }
+}
+
+impl From<&str> for EvalsTo<String, String> {
+    fn from(s: &str) -> Self {
+        EvalsTo::Const(s.to_string())
+    }
+}
+impl From<String> for EvalsTo<String, String> {
+    fn from(s: String) -> Self {
+        EvalsTo::Const(s)
+    }
+}
+
 /// Conditional control structure.
 ///
 /// Example:
diff --git a/pdl-live-react/src-tauri/src/pdl/extract.rs b/pdl-live-react/src-tauri/src/pdl/extract.rs
index 41eb492fc..5b9d34514 100644
--- a/pdl-live-react/src-tauri/src/pdl/extract.rs
+++ b/pdl-live-react/src-tauri/src/pdl/extract.rs
@@ -1,4 +1,6 @@
-use crate::pdl::ast::{Block, Body::*, Metadata, PdlBlock, PdlBlock::Advanced};
+use crate::pdl::ast::{
+    Block, Body::*, EvalsTo, Expr, Metadata, ModelBlock, PdlBlock, PdlBlock::Advanced,
+};
 
 /// Extract models referenced by the programs
 pub fn extract_models(program: &PdlBlock) -> Vec<String> {
@@ -28,7 +30,30 @@ fn extract_values_iter(program: &PdlBlock, field: &str, values: &mut Vec<String>) {
         PdlBlock::Function(b) => {
             extract_values_iter(&b.return_, field, values);
         }
-        Advanced(Block { body: Model(b), .. }) => values.push(b.model.clone()),
+        Advanced(Block {
+            body:
+                Model(ModelBlock {
+                    model: EvalsTo::<String, String>::Const(m),
+                    ..
+                }),
+            ..
+        }) => values.push(m.clone()),
+        Advanced(Block {
+            body:
+                Model(ModelBlock {
+                    model: EvalsTo::<String, String>::Jinja(m),
+                    ..
+                }),
+            ..
+        }) => values.push(m.clone()),
+        Advanced(Block {
+            body:
+                Model(ModelBlock {
+                    model: EvalsTo::<String, String>::Expr(Expr { pdl_expr: m, .. }),
+                    ..
+                }),
+            ..
+        }) => values.push(m.clone()),
         Advanced(Block { body: Repeat(b), .. }) => {
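Note: with `model` now typed as `EvalsTo<String, String>`, a model reference can arrive from a PDL program either as a bare string or as an object carrying a `pdl__expr` entry (the new interpreter test at the end of this patch exercises the latter). The following is a minimal standalone sketch of how an untagged serde enum absorbs both shapes; the `#[serde(untagged)]` attribute, the variant order, and the omission of the real enum's `Jinja` variant are simplifying assumptions here, since the enum's attributes sit outside this hunk:

    // Standalone sketch, not part of the patch. Assumes the serde (with
    // the "derive" feature) and serde_json crates.
    use serde::Deserialize;

    #[derive(Deserialize, Debug)]
    struct SketchExpr {
        #[serde(rename = "pdl__expr")]
        pdl_expr: String,
    }

    #[derive(Deserialize, Debug)]
    #[serde(untagged)]
    enum SketchEvalsTo {
        Expr(SketchExpr), // {"pdl__expr": "..."} selects the expression form
        Const(String),    // a bare "ollama/..." string stays a constant
    }

    fn main() -> Result<(), serde_json::Error> {
        let c: SketchEvalsTo = serde_json::from_str(r#""ollama/some-model""#)?;
        let e: SketchEvalsTo = serde_json::from_str(r#"{"pdl__expr": "{{ some_var }}"}"#)?;
        println!("{c:?} {e:?}");
        Ok(())
    }

Because the string-holding variants carry the same payload, the interpreter's new `eval_string_to_string` (next file) can treat `Const` and `Jinja` identically.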
diff --git a/pdl-live-react/src-tauri/src/pdl/interpreter.rs b/pdl-live-react/src-tauri/src/pdl/interpreter.rs
index a69595b56..2a5548996 100644
--- a/pdl-live-react/src-tauri/src/pdl/interpreter.rs
+++ b/pdl-live-react/src-tauri/src/pdl/interpreter.rs
@@ -261,6 +261,18 @@ impl<'a> Interpreter<'a> {
         }
     }
 
+    // TODO how can we better cope with the expected String return?
+    fn eval_string_to_string(
+        &self,
+        expr: &EvalsTo<String, String>,
+        state: &State,
+    ) -> Result<PdlResult, PdlError> {
+        match expr {
+            EvalsTo::Const(s) | EvalsTo::Jinja(s) => self.eval(s, state),
+            EvalsTo::Expr(e) => self.eval(&e.pdl_expr, state),
+        }
+    }
+
     /// Evaluate an Expr to a bool
     fn eval_to_bool(
         &self,
@@ -722,140 +734,149 @@ impl<'a> Interpreter<'a> {
         })
     }
 
-    /// Run a PdlBlock::Model
-    async fn run_model(
+    async fn run_ollama_model(
         &mut self,
+        pdl_model: String,
         block: &ModelBlock,
         metadata: &Metadata,
         state: &mut State,
     ) -> BodyInterpretation {
-        match &block.model {
-            pdl_model
-                if pdl_model.starts_with("ollama/") || pdl_model.starts_with("ollama_chat/") =>
-            {
-                let mut ollama = Ollama::default();
-                let model = if pdl_model.starts_with("ollama/") {
-                    &pdl_model[7..]
-                } else {
-                    &pdl_model[12..]
-                };
+        let mut ollama = Ollama::default();
+        let model = if pdl_model.starts_with("ollama/") {
+            &pdl_model[7..]
+        } else {
+            &pdl_model[12..]
+        };
 
-                let (options, tools) = self.to_ollama_model_options(&block.parameters);
-                if self.options.debug {
-                    eprintln!("Model options {:?} {:?}", metadata.description, options);
-                    eprintln!("Model tools {:?} {:?}", metadata.description, tools);
-                }
+        let (options, tools) = self.to_ollama_model_options(&block.parameters);
+        if self.options.debug {
+            eprintln!("Model options {:?} {:?}", metadata.description, options);
+            eprintln!("Model tools {:?} {:?}", metadata.description, tools);
+        }
 
-                // The input messages to the model is either:
-                // a) block.input, if given
-                // b) the current state's accumulated messages
-                let input_messages = match &block.input {
-                    Some(input) => {
-                        // TODO ignoring result, trace
-                        let (_result, messages, _trace) = self.run_quiet(&*input, state).await?;
-                        messages
-                    }
-                    None => state.messages.clone(),
-                };
-                let (prompt, history_slice): (&ChatMessage, &[ChatMessage]) =
-                    match input_messages.split_last() {
-                        Some(x) => x,
-                        None => (&ChatMessage::user("".into()), &[]),
-                    };
-                let mut history = Vec::from(history_slice);
-                if self.options.debug {
-                    eprintln!(
-                        "Ollama {:?} model={:?} prompt={:?} history={:?}",
-                        metadata.description, block.model, prompt, history
-                    );
-                }
+        // The input messages to the model is either:
+        // a) block.input, if given
+        // b) the current state's accumulated messages
+        let input_messages = match &block.input {
+            Some(input) => {
+                // TODO ignoring result, trace
+                let (_result, messages, _trace) = self.run_quiet(&*input, state).await?;
+                messages
+            }
+            None => state.messages.clone(),
+        };
+        let (prompt, history_slice): (&ChatMessage, &[ChatMessage]) =
+            match input_messages.split_last() {
+                Some(x) => x,
+                None => (&ChatMessage::user("".into()), &[]),
+            };
+        let mut history = Vec::from(history_slice);
+        if self.options.debug {
+            eprintln!(
+                "Ollama {:?} model={:?} prompt={:?} history={:?}",
+                metadata.description, block.model, prompt, history
+            );
+        }
 
-                //if state.emit {
-                //println!("{}", pretty_print(&input_messages));
-                //}
+        //if state.emit {
+        //println!("{}", pretty_print(&input_messages));
+        //}
 
-                let req = ChatMessageRequest::new(model.into(), vec![prompt.clone()])
-                    .options(options)
-                    .tools(tools);
+        let req = ChatMessageRequest::new(model.into(), vec![prompt.clone()])
+            .options(options)
+            .tools(tools);
 
-                let (last_res, response_string) = if !self.options.stream {
-                    let res = ollama
-                        .send_chat_messages_with_history(&mut history, req)
-                        .await?;
-                    let response_string = res.message.content.clone();
-                    print!("{}", response_string);
-                    (Some(res), response_string)
-                } else {
-                    let mut stream = ollama
-                        .send_chat_messages_with_history_stream(
-                            ::std::sync::Arc::new(::std::sync::Mutex::new(history)),
-                            req,
-                            //ollama.generate(GenerationRequest::new(model.into(), prompt),
-                        )
-                        .await?;
-                    // dbg!("Model result {:?}", &res);
+        let (last_res, response_string) = if !self.options.stream {
+            let res = ollama
+                .send_chat_messages_with_history(&mut history, req)
+                .await?;
+            let response_string = res.message.content.clone();
+            print!("{}", response_string);
+            (Some(res), response_string)
+        } else {
+            let mut stream = ollama
+                .send_chat_messages_with_history_stream(
+                    ::std::sync::Arc::new(::std::sync::Mutex::new(history)),
+                    req,
+                    //ollama.generate(GenerationRequest::new(model.into(), prompt),
+                )
+                .await?;
+            // dbg!("Model result {:?}", &res);
 
-                    let emit = if let Some(_) = &block.model_response {
-                        false
-                    } else {
-                        true
-                    };
+            let emit = if let Some(_) = &block.model_response {
+                false
+            } else {
+                true
+            };
 
-                    let mut last_res: Option<ChatMessageResponse> = None;
-                    let mut response_string = String::new();
-                    let mut stdout = stdout();
-                    if emit {
-                        stdout.write_all(b"\x1b[1mAssistant: \x1b[0m").await?;
-                    }
-                    while let Some(Ok(res)) = stream.next().await {
-                        if emit {
-                            stdout.write_all(b"\x1b[32m").await?; // green
-                            stdout.write_all(res.message.content.as_bytes()).await?;
-                            stdout.flush().await?;
-                            stdout.write_all(b"\x1b[0m").await?; // reset color
-                        }
-                        response_string += res.message.content.as_str();
-                        last_res = Some(res);
-                    }
-                    if emit {
-                        stdout.write_all(b"\n").await?;
-                    }
+            let mut last_res: Option<ChatMessageResponse> = None;
+            let mut response_string = String::new();
+            let mut stdout = stdout();
+            if emit {
+                stdout.write_all(b"\x1b[1mAssistant: \x1b[0m").await?;
+            }
+            while let Some(Ok(res)) = stream.next().await {
+                if emit {
+                    stdout.write_all(b"\x1b[32m").await?; // green
+                    stdout.write_all(res.message.content.as_bytes()).await?;
+                    stdout.flush().await?;
+                    stdout.write_all(b"\x1b[0m").await?; // reset color
+                }
+                response_string += res.message.content.as_str();
+                last_res = Some(res);
+            }
+            if emit {
+                stdout.write_all(b"\n").await?;
+            }
 
-                    (last_res, response_string)
-                };
+            (last_res, response_string)
+        };
 
-                if let Some(_) = &block.model_response {
-                    if let Some(ref res) = last_res {
-                        self.def(
-                            &block.model_response,
-                            &resultify_as_litellm(&from_str(&to_string(&res)?)?),
-                            &None,
-                            state,
-                            true,
-                        )?;
-                    }
-                }
+        if let Some(_) = &block.model_response {
+            if let Some(ref res) = last_res {
+                self.def(
+                    &block.model_response,
+                    &resultify_as_litellm(&from_str(&to_string(&res)?)?),
+                    &None,
+                    state,
+                    true,
+                )?;
+            }
+        }
 
-                let mut trace = block.clone();
-                if let Some(res) = last_res {
-                    if let Some(usage) = res.final_data {
-                        trace.pdl_usage = Some(PdlUsage {
-                            prompt_tokens: usage.prompt_eval_count,
-                            prompt_nanos: Some(usage.prompt_eval_duration),
-                            completion_tokens: usage.eval_count,
-                            completion_nanos: Some(usage.eval_duration),
-                        });
-                    }
-                    let output_messages = vec![ChatMessage::assistant(response_string)];
-                    Ok((res.message.content.into(), output_messages, Model(trace)))
-                } else {
-                    // nothing came out of the model
-                    Ok(("".into(), vec![], Model(trace)))
-                }
-                // dbg!(history);
-            }
-            _ => Err(Box::from(format!("Unsupported model {}", block.model))),
-        }
+        let mut trace = block.clone();
+        if let Some(res) = last_res {
+            if let Some(usage) = res.final_data {
+                trace.pdl_usage = Some(PdlUsage {
+                    prompt_tokens: usage.prompt_eval_count,
+                    prompt_nanos: Some(usage.prompt_eval_duration),
+                    completion_tokens: usage.eval_count,
+                    completion_nanos: Some(usage.eval_duration),
+                });
+            }
+            let output_messages = vec![ChatMessage::assistant(response_string)];
+            Ok((res.message.content.into(), output_messages, Model(trace)))
+        } else {
+            // nothing came out of the model
+            Ok(("".into(), vec![], Model(trace)))
+        }
+        // dbg!(history);
+    }
+
+    /// Run a PdlBlock::Model
+    async fn run_model(
+        &mut self,
+        block: &ModelBlock,
+        metadata: &Metadata,
+        state: &mut State,
+    ) -> BodyInterpretation {
+        if let PdlResult::String(s) = self.eval_string_to_string(&block.model, state)? {
+            if s.starts_with("ollama/") || s.starts_with("ollama_chat/") {
+                return self.run_ollama_model(s, block, metadata, state).await;
+            }
+        }
+
+        Err(Box::from(format!("Unsupported model {:?}", block.model)))
     }
 
     /// Run a PdlBlock::Data
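Note: `run_model` now evaluates the model field first (`eval_string_to_string` treats `Const` and `Jinja` alike and unwraps an `Expr` to its `pdl__expr` string) and only then dispatches on the scheme prefix, so a Jinja template or expression can choose the model at run time. In `run_ollama_model`, the magic numbers 7 and 12 are the lengths of "ollama/" and "ollama_chat/". A standalone sketch of the same prefix handling written with `str::strip_prefix`, an equivalent alternative to the slicing the patch uses (model names here are illustrative):

    // Sketch only: equivalent to the &pdl_model[7..] / &pdl_model[12..]
    // slicing in run_ollama_model above.
    fn ollama_model_name(pdl_model: &str) -> Option<&str> {
        pdl_model
            .strip_prefix("ollama/")
            .or_else(|| pdl_model.strip_prefix("ollama_chat/"))
    }

    fn main() {
        assert_eq!(ollama_model_name("ollama/llama3"), Some("llama3"));
        assert_eq!(ollama_model_name("ollama_chat/llama3"), Some("llama3"));
        // Any other scheme falls through to the "Unsupported model" Err path.
        assert_eq!(ollama_model_name("openai/gpt-4o"), None);
    }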
diff --git a/pdl-live-react/src-tauri/src/pdl/interpreter_tests.rs b/pdl-live-react/src-tauri/src/pdl/interpreter_tests.rs
index ba171ee74..b285016b0 100644
--- a/pdl-live-react/src-tauri/src/pdl/interpreter_tests.rs
+++ b/pdl-live-react/src-tauri/src/pdl/interpreter_tests.rs
@@ -79,6 +79,26 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn single_model_via_text_chain_expr() -> Result<(), Box<dyn Error>> {
+        let (_, messages, _) = run_json(
+            json!({
+                "text": [
+                    "hello",
+                    {"model": { "pdl__expr": DEFAULT_MODEL }}
+                ]
+            }),
+            streaming(),
+            initial_scope(),
+        )?;
+        assert_eq!(messages.len(), 2);
+        assert_eq!(messages[0].role, MessageRole::User);
+        assert_eq!(messages[0].content, "hello");
+        assert_eq!(messages[1].role, MessageRole::Assistant);
+        assert!(messages[1].content.contains("Hello!"));
+        Ok(())
+    }
+
     #[test]
     fn single_model_via_text_chain() -> Result<(), Box<dyn Error>> {
         let (_, messages, _) = run_json(
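Note: the `Default` and `From` impls added in ast.rs are what keep existing call sites compiling after the type change. `ModelBlock` carries `#[builder(setter(into, strip_option), default)]` (visible in the first hunk), so, assuming derive_builder's usual codegen, the generated `.model(...)` setter accepts any `impl Into<EvalsTo<String, String>>`, and `default` requires the field type to implement `Default`. A standalone sketch with simplified stand-in types showing the conversion path a string-typed call site now takes:

    // Sketch only: EvalsTo reduced to the single variant the From impls
    // produce; the real enum and builder live in ast.rs.
    #[derive(Debug, PartialEq)]
    enum EvalsTo {
        Const(String),
    }

    impl From<&str> for EvalsTo {
        fn from(s: &str) -> Self {
            EvalsTo::Const(s.to_string())
        }
    }

    // The shape of setter that #[builder(setter(into))] generates.
    fn model<T: Into<EvalsTo>>(value: T) -> EvalsTo {
        value.into()
    }

    fn main() {
        // A call site that used to pass a plain &str still compiles:
        assert_eq!(model("ollama/llama3"), EvalsTo::Const("ollama/llama3".into()));
    }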