From d75b6a77bcf9aa94f4014e5a7db5baa8b992cd0c Mon Sep 17 00:00:00 2001
From: Nick Mitchell
Date: Wed, 9 Apr 2025 10:05:41 -0400
Subject: [PATCH] feat: rust interpreter support for modelResponse, ollama-rs tool calling, and no-stream

This adds a test for the bee compiler, and that test also runs the compiled code.

Signed-off-by: Nick Mitchell
---
 pdl-live-react/src-tauri/Cargo.lock           |   4 +-
 pdl-live-react/src-tauri/Cargo.toml           |   4 +-
 pdl-live-react/src-tauri/src/cli.rs           |   7 +-
 pdl-live-react/src-tauri/src/compile/beeai.rs |  24 ++-
 .../src-tauri/src/pdl/interpreter.rs          | 202 ++++++++++++------
 .../src-tauri/src/pdl/interpreter_tests.rs    |  28 ++-
 pdl-live-react/src-tauri/tauri.conf.json      |   3 +
 pdl-live-react/src-tauri/tests/data/bee_1.py  |  64 ++++++
 8 files changed, 254 insertions(+), 82 deletions(-)
 create mode 100644 pdl-live-react/src-tauri/tests/data/bee_1.py

diff --git a/pdl-live-react/src-tauri/Cargo.lock b/pdl-live-react/src-tauri/Cargo.lock
index 1779e2668..f3f6acca7 100644
--- a/pdl-live-react/src-tauri/Cargo.lock
+++ b/pdl-live-react/src-tauri/Cargo.lock
@@ -3168,8 +3168,7 @@ dependencies = [
 [[package]]
 name = "ollama-rs"
 version = "0.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7a4b4750770584c8b4a643d0329e7bedacc4ecf68b7c7ac3e1fec2bafd6312f7"
+source = "git+https://github.com/starpit/ollama-rs.git?branch=tools-pub-7#90820473b4b6e037a82f4c5bd2e254935bae8da7"
 dependencies = [
  "async-stream",
  "log",
@@ -3376,6 +3375,7 @@ dependencies = [
  "rayon",
  "rustpython-stdlib",
  "rustpython-vm",
+ "schemars",
  "serde",
  "serde_json",
  "serde_norway",
diff --git a/pdl-live-react/src-tauri/Cargo.toml b/pdl-live-react/src-tauri/Cargo.toml
index 62534f2e5..a26aa414a 100644
--- a/pdl-live-react/src-tauri/Cargo.toml
+++ b/pdl-live-react/src-tauri/Cargo.toml
@@ -35,7 +35,8 @@ base64ct = { version = "1.7.1", features = ["alloc"] }
 dirs = "6.0.0"
 serde_norway = "0.9.42"
 minijinja = { version = "2.9.0", features = ["custom_syntax"] }
-ollama-rs = { version = "0.3.0", features = ["stream"] }
+#ollama-rs = { version = "0.3.0", features = ["stream"] }
+ollama-rs = { git = "https://github.com/starpit/ollama-rs.git", branch = "tools-pub-7", features = ["stream"] }
 owo-colors = "4.2.0"
 rustpython-vm = "0.4.0"
 async-recursion = "1.1.1"
@@ -43,6 +44,7 @@ tokio-stream = "0.1.17"
 tokio = { version = "1.44.1", features = ["io-std"] }
 indexmap = { version = "2.9.0", features = ["serde"] }
 rustpython-stdlib = { version = "0.4.0", features = ["zlib"] }
+schemars = "0.8.22"
 
 [target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies]
 tauri-plugin-cli = "2"
diff --git a/pdl-live-react/src-tauri/src/cli.rs b/pdl-live-react/src-tauri/src/cli.rs
index 06e15541b..2e2909031 100644
--- a/pdl-live-react/src-tauri/src/cli.rs
+++ b/pdl-live-react/src-tauri/src/cli.rs
@@ -34,7 +34,7 @@ pub fn setup(app: &mut tauri::App) -> Result>
 
         let args = compile_subcommand_matches.matches.args;
         match compile_subcommand_matches.name.as_str() {
-            "beeai" => compile::beeai::compile(
+            "beeai" => compile::beeai::compile_to_file(
                 args.get("source")
                     .and_then(|a| a.value.as_str())
                     .expect("valid positional source arg"),
@@ -60,6 +60,11 @@ pub fn setup(app: &mut tauri::App) -> Result>
                     .and_then(|a| a.value.as_bool())
                     .or(Some(false))
                     == Some(true),
+                subcommand_args
+                    .get("no-stream")
+                    .and_then(|a| a.value.as_bool())
+                    .or(Some(false))
+                    == Some(false),
             )
             .and_then(|_trace| Ok(true)),
             "run" => run_pdl_program(
diff --git a/pdl-live-react/src-tauri/src/compile/beeai.rs b/pdl-live-react/src-tauri/src/compile/beeai.rs
index c686f2119..bd589fc94 100644
--- a/pdl-live-react/src-tauri/src/compile/beeai.rs
+++ b/pdl-live-react/src-tauri/src/compile/beeai.rs
@@ -348,15 +348,7 @@ fn python_source_to_json(source_file_path: &str, debug: bool) -> Result
-pub fn compile(
-    source_file_path: &str,
-    output_path: &str,
-    debug: bool,
-) -> Result<(), Box> {
-    if debug {
-        eprintln!("Compiling beeai {} to {}", source_file_path, output_path);
-    }
-
+pub fn compile(source_file_path: &str, debug: bool) -> Result> {
     let file = match Path::new(source_file_path)
         .extension()
         .and_then(OsStr::to_str)
@@ -559,6 +551,20 @@ asyncio.run(invoke())
         text: body,
     });
 
+    Ok(pdl)
+}
+
+pub fn compile_to_file(
+    source_file_path: &str,
+    output_path: &str,
+    debug: bool,
+) -> Result<(), Box> {
+    if debug {
+        eprintln!("Compiling beeai {} to {}", source_file_path, output_path);
+    }
+
+    let pdl = compile(source_file_path, debug)?;
+
     match output_path {
         "-" => println!("{}", to_string(&pdl)?),
         _ => {
diff --git a/pdl-live-react/src-tauri/src/pdl/interpreter.rs b/pdl-live-react/src-tauri/src/pdl/interpreter.rs
index 112cccf6b..a0e9aac93 100644
--- a/pdl-live-react/src-tauri/src/pdl/interpreter.rs
+++ b/pdl-live-react/src-tauri/src/pdl/interpreter.rs
@@ -15,13 +15,13 @@ use tokio_stream::StreamExt;
 use ollama_rs::{
     generation::{
         chat::{request::ChatMessageRequest, ChatMessage, ChatMessageResponse, MessageRole},
-        tools::ToolInfo,
+        tools::{ToolFunctionInfo, ToolInfo, ToolType},
     },
     models::ModelOptions,
     Ollama,
 };
 
-use serde_json::{from_str, to_string, Value};
+use serde_json::{from_str, json, to_string, Value};
 use serde_norway::{from_reader, from_str as from_yaml_str};
 
 use crate::pdl::ast::{
@@ -45,6 +45,7 @@ struct Interpreter<'a> {
     scope: Vec<Scope>,
     debug: bool,
     emit: bool,
+    stream: bool,
 }
 
 impl<'a> Interpreter<'a> {
@@ -67,6 +68,7 @@ impl<'a> Interpreter<'a> {
             scope: vec![Scope::new()],
             debug: false,
             emit: true,
+            stream: true,
         }
     }
 
@@ -76,14 +78,13 @@ impl<'a> Interpreter<'a> {
         context: Context,
         emit: bool,
     ) -> Interpretation {
-        if self.debug {
+        /* if self.debug {
             if let Some(scope) = self.scope.last() {
                 if scope.len() > 0 {
                     eprintln!("Run with Scope {:?}", scope);
                 }
             }
-        }
-
+        } */
         let prior_emit = self.emit;
         self.emit = emit;
 
@@ -118,7 +119,9 @@ impl<'a> Interpreter<'a> {
         }?;
 
         if match program {
-            PdlBlock::Call(_) | PdlBlock::Model(_) => false,
+            PdlBlock::Text(_) | PdlBlock::LastOf(_) | PdlBlock::Call(_) | PdlBlock::Model(_) => {
+                false
+            }
             _ => self.emit,
         } {
             println!("{}", pretty_print(&messages));
@@ -150,7 +153,7 @@ impl<'a> Interpreter<'a> {
         let backup = result.clone();
         Ok(from_str(&result).unwrap_or_else(|err| {
             if self.debug {
-                eprintln!("Treating as plain string {}", &result);
+                eprintln!("Treating as plain string {}", result);
                 eprintln!("... due to {}", err);
             }
             backup.into()
@@ -332,7 +335,10 @@ impl<'a> Interpreter<'a> {
 
                 self.run(&c.function.return_, context.clone()).await
             }
-            _ => Err(Box::from(format!("call of non-function {:?}", &block.call))),
+            x => Err(Box::from(format!(
+                "call of non-function {:?}->{:?}",
+                block.call, x
+            ))),
         };
 
         if let Some(_) = block.args {
@@ -438,10 +444,36 @@ impl<'a> Interpreter<'a> {
             0.0
         };
 
-        let tools = if let Some(Value::Array(_tools)) = parameters.get(&"tools".to_string()) {
-            // TODO
-            //tools.into_iter().map(|tool| function!()).collect()
-            vec![]
+        let tools = if let Some(Value::Array(tools)) = parameters.get("tools") {
+            tools
+                .into_iter()
+                .filter_map(|tool| tool.get("function"))
+                .filter_map(|tool| {
+                    //from_str(&to_string(tool)?)
+                    match (
+                        tool.get("name"),
+                        tool.get("description"),
+                        tool.get("parameters"),
+                    ) {
+                        (
+                            Some(Value::String(name)),
+                            Some(Value::String(description)),
+                            Some(Value::Object(parameters)),
+                        ) => Some(ToolInfo {
+                            tool_type: ToolType::Function,
+                            function: ToolFunctionInfo {
+                                name: name.to_string(),
+                                description: description.to_string(),
+                                parameters: schemars::schema_for_value!(parameters),
+                            },
+                        }),
+                        _ => {
+                            eprintln!("Error: tools do not satisfy schema {:?}", tool);
+                            None
+                        }
+                    }
+                })
+                .collect()
         } else {
             vec![]
         };
@@ -517,7 +549,7 @@ impl<'a> Interpreter<'a> {
             pdl_model
                 if pdl_model.starts_with("ollama/") || pdl_model.starts_with("ollama_chat/") =>
             {
-                let ollama = Ollama::default();
+                let mut ollama = Ollama::default();
                 let model = if pdl_model.starts_with("ollama/") {
                     &pdl_model[7..]
                 } else {
@@ -526,7 +558,8 @@ impl<'a> Interpreter<'a> {
 
                 let (options, tools) = self.to_ollama_model_options(&block.parameters);
                 if self.debug {
-                    println!("Model options {:?}", options);
+                    eprintln!("Model options {:?} {:?}", block.description, options);
+                    eprintln!("Model tools {:?} {:?}", block.description, tools);
                 }
 
                 let input_messages = match &block.input {
@@ -542,7 +575,7 @@ impl<'a> Interpreter<'a> {
                     Some(x) => x,
                     None => (&ChatMessage::user("".into()), &[]),
                 };
-                let history = Vec::from(history_slice);
+                let mut history = Vec::from(history_slice);
                 if self.debug {
                     eprintln!(
                         "Ollama {:?} model={:?} prompt={:?} history={:?}",
@@ -560,52 +593,62 @@ impl<'a> Interpreter<'a> {
                 let req = ChatMessageRequest::new(model.into(), vec![prompt.clone()])
                     .options(options)
                     .tools(tools);
 
-                /* if we ever want non-streaming:
-                let res = ollama
-                    .send_chat_messages_with_history(
-                        &mut history,
-                        req,
-                        //ollama.generate(GenerationRequest::new(model.into(), prompt),
-                    )
-                    .await?;
-                // dbg!("Model result {:?}", &res);
-                let mut trace = block.clone();
-                trace.pdl_result = Some(res.message.content.clone());
-
-                if let Some(usage) = res.final_data {
-                    trace.pdl_usage = Some(PdlUsage {
-                        prompt_tokens: usage.prompt_eval_count,
-                        prompt_nanos: usage.prompt_eval_duration,
-                        completion_tokens: usage.eval_count,
-                        completion_nanos: usage.eval_duration,
-                    });
-                }
-                // dbg!(history);
-                Ok((vec![res.message], PdlBlock::Model(trace)))
-                */
-                let mut stream = ollama
-                    .send_chat_messages_with_history_stream(
-                        Arc::new(Mutex::new(history)),
-                        req,
-                        //ollama.generate(GenerationRequest::new(model.into(), prompt),
-                    )
-                    .await?;
-                // dbg!("Model result {:?}", &res);
-
-                let mut last_res: Option<ChatMessageResponse> = None;
-                let mut response_string = String::new();
-                let mut stdout = stdout();
-                stdout.write_all(b"\x1b[1mAssistant: \x1b[0m").await?;
-                while let Some(Ok(res)) = stream.next().await {
-                    stdout.write_all(b"\x1b[32m").await?; // green
-                    stdout.write_all(res.message.content.as_bytes()).await?;
-                    stdout.flush().await?;
-                    stdout.write_all(b"\x1b[0m").await?; // reset color
-                    response_string += res.message.content.as_str();
-                    last_res = Some(res);
+                let (last_res, response_string) = if !self.stream {
+                    let res = ollama
+                        .send_chat_messages_with_history(&mut history, req)
+                        .await?;
+                    let response_string = res.message.content.clone();
+                    print!("{}", response_string);
+                    (Some(res), response_string)
+                } else {
+                    let mut stream = ollama
+                        .send_chat_messages_with_history_stream(
+                            Arc::new(Mutex::new(history)),
+                            req,
+                            //ollama.generate(GenerationRequest::new(model.into(), prompt),
+                        )
+                        .await?;
+                    // dbg!("Model result {:?}", &res);
+
+                    let emit = if let Some(_) = &block.model_response {
+                        false
+                    } else {
+                        true
+                    };
+
+                    let mut last_res: Option<ChatMessageResponse> = None;
+                    let mut response_string = String::new();
+                    let mut stdout = stdout();
+                    if emit {
+                        stdout.write_all(b"\x1b[1mAssistant: \x1b[0m").await?;
+                    }
+                    while let Some(Ok(res)) = stream.next().await {
+                        if emit {
+                            stdout.write_all(b"\x1b[32m").await?; // green
+                            stdout.write_all(res.message.content.as_bytes()).await?;
+                            stdout.flush().await?;
+                            stdout.write_all(b"\x1b[0m").await?; // reset color
+                        }
+                        response_string += res.message.content.as_str();
+                        last_res = Some(res);
+                    }
+                    if emit {
+                        stdout.write_all(b"\n").await?;
+                    }
+
+                    (last_res, response_string)
+                };
+
+                if let Some(_) = &block.model_response {
+                    if let Some(ref res) = last_res {
+                        self.def(
+                            &block.model_response,
+                            &self.resultify_as_litellm(&from_str(&to_string(&res)?)?),
+                            &None,
+                        )?;
+                    }
                 }
-                stdout.write_all(b"\n").await?;
 
                 let mut trace = block.clone();
                 trace.pdl_result = Some(response_string.clone());
@@ -653,6 +696,15 @@ impl<'a> Interpreter<'a> {
         }
     }
 
+    /// Transform a JSON Value into a PdlResult object that is compatible with litellm's model response schema
+    fn resultify_as_litellm(&self, value: &Value) -> PdlResult {
+        self.resultify(&json!({
+            "choices": [
+                value
+            ]
+        }))
+    }
+
     /// Run a PdlBlock::Data
     async fn run_data(&mut self, block: &DataBlock, _context: Context) -> Interpretation {
         if self.debug {
@@ -821,7 +873,7 @@ impl<'a> Interpreter<'a> {
         while let Some(block) = iter.next() {
             // run each element of the Text block
             let (this_result, this_messages, trace) =
-                self.run_quiet(&block, input_messages.clone()).await?;
+                self.run(&block, input_messages.clone()).await?;
 
             input_messages.extend(this_messages.clone());
             output_results.push(this_result);
@@ -918,17 +970,29 @@ impl<'a> Interpreter<'a> {
     }
 }
 
-pub async fn run(program: &PdlBlock, cwd: Option<PathBuf>, debug: bool) -> Interpretation {
+pub async fn run(
+    program: &PdlBlock,
+    cwd: Option<PathBuf>,
+    debug: bool,
+    stream: bool,
+) -> Interpretation {
     let mut interpreter = Interpreter::new();
     interpreter.debug = debug;
+    interpreter.stream = stream;
     if let Some(cwd) = cwd {
         interpreter.cwd = cwd
     };
     interpreter.run(&program, vec![]).await
 }
 
-pub fn run_sync(program: &PdlBlock, cwd: Option<PathBuf>, debug: bool) -> InterpretationSync {
-    tauri::async_runtime::block_on(run(program, cwd, debug))
+#[allow(dead_code)]
+pub fn run_sync(
+    program: &PdlBlock,
+    cwd: Option<PathBuf>,
+    debug: bool,
+    stream: bool,
+) -> InterpretationSync {
+    tauri::async_runtime::block_on(run(program, cwd, debug, stream))
         .map_err(|err| Box::::from(err.to_string()))
 }
 
@@ -938,28 +1002,30 @@ pub fn parse_file(path: &PathBuf) -> Result {
         .map_err(|err| Box::::from(err.to_string()))
 }
 
-pub async fn run_file(source_file_path: &str, debug: bool) -> Interpretation {
+pub async fn run_file(source_file_path: &str, debug: bool, stream: bool) -> Interpretation {
     let path = PathBuf::from(source_file_path);
     let cwd = path.parent().and_then(|cwd| Some(cwd.to_path_buf()));
     let program = parse_file(&path)?;
     crate::pdl::pull::pull_if_needed(&program).await?;
-    run(&program, cwd, debug).await
+    run(&program, cwd, debug, stream).await
 }
 
-pub fn run_file_sync(source_file_path: &str, debug: bool) -> InterpretationSync {
-    tauri::async_runtime::block_on(run_file(source_file_path, debug))
+pub fn run_file_sync(source_file_path: &str, debug: bool, stream: bool) -> InterpretationSync {
+    tauri::async_runtime::block_on(run_file(source_file_path, debug, stream))
        .map_err(|err| Box::::from(err.to_string()))
 }
 
 pub async fn run_string(source: &str, debug: bool) -> Interpretation {
-    run(&from_yaml_str(source)?, None, debug).await
+    run(&from_yaml_str(source)?, None, debug, true).await
 }
 
+#[allow(dead_code)]
 pub async fn run_json(source: Value, debug: bool) -> Interpretation {
     run_string(&to_string(&source)?, debug).await
 }
 
+#[allow(dead_code)]
 pub fn run_json_sync(source: Value, debug: bool) -> InterpretationSync {
     tauri::async_runtime::block_on(run_json(source, debug))
         .map_err(|err| Box::::from(err.to_string()))
 }
diff --git a/pdl-live-react/src-tauri/src/pdl/interpreter_tests.rs b/pdl-live-react/src-tauri/src/pdl/interpreter_tests.rs
index 7d9f3c278..436aef3cc 100644
--- a/pdl-live-react/src-tauri/src/pdl/interpreter_tests.rs
+++ b/pdl-live-react/src-tauri/src/pdl/interpreter_tests.rs
@@ -15,7 +15,7 @@ mod tests {
 
     #[test]
     fn string() -> Result<(), Box> {
-        let (_, messages, _) = run(&"hello".into(), None, false)?;
+        let (_, messages, _) = run(&"hello".into(), None, false, true)?;
         assert_eq!(messages.len(), 1);
         assert_eq!(messages[0].role, MessageRole::User);
         assert_eq!(messages[0].content, "hello");
@@ -28,6 +28,7 @@ mod tests {
             &PdlBlock::Model(ModelBlock::new(DEFAULT_MODEL).input_str("hello").build()),
             None,
             false,
+            true,
         )?;
         assert_eq!(messages.len(), 1);
         assert_eq!(messages[0].role, MessageRole::Assistant);
@@ -534,4 +535,29 @@ mod tests {
         assert_eq!(messages[0].content, "Bye!");
         Ok(())
     }
+
+    #[test]
+    fn bee_1() -> Result<(), Box> {
+        let program = crate::compile::beeai::compile("./tests/data/bee_1.py", false)?;
+        let (_, messages, _) = run(&program, None, false, false)?;
+        assert_eq!(messages.len(), 9);
+        assert!(
+            messages.iter().any(|m| m.role == MessageRole::User
+                && m.content == "Provide a short history of Saint-Tropez."),
+            "Could not find message user:Provide a short history of Saint-Tropez. in {:?}",
+            messages
+        );
+        assert!(
+            messages.iter().any(|m| m.role == MessageRole::System
+                && m.content
+                    == "You can combine disparate information into a final coherent summary."),
+            "Could not find message system:You can combine disparate information into a final coherent summary. in {:?}",
+            messages
+        );
+        // assert!(messages.iter().any(|m| m.role == MessageRole::Assistant && m.content.contains("a renowned French Riviera town")), "Could not find message assistant:a renowned French Riviera town in {:?}", messages);
+        //assert_eq!(true, messages.iter().any(|m| m.role == MessageRole::Assistant && m.content.contains("I'll use the OpenMeteoTool")));
+        //assert_eq!(true, messages.iter().any(|m| m.role == MessageRole::Assistant && m.content.contains("The current temperature in Saint-Tropez")));
+
+        Ok(())
+    }
 }
diff --git a/pdl-live-react/src-tauri/tauri.conf.json b/pdl-live-react/src-tauri/tauri.conf.json
index e35d1fd32..108356584 100644
--- a/pdl-live-react/src-tauri/tauri.conf.json
+++ b/pdl-live-react/src-tauri/tauri.conf.json
@@ -55,6 +55,9 @@
               "required": true,
               "takesValue": true
             },
+            {
+              "name": "no-stream"
+            },
             {
               "name": "debug",
               "short": "g"
diff --git a/pdl-live-react/src-tauri/tests/data/bee_1.py b/pdl-live-react/src-tauri/tests/data/bee_1.py
new file mode 100644
index 000000000..85fabf131
--- /dev/null
+++ b/pdl-live-react/src-tauri/tests/data/bee_1.py
@@ -0,0 +1,64 @@
+import asyncio
+
+from beeai_framework.backend.chat import ChatModel
+from beeai_framework.tools.search.wikipedia import WikipediaTool
+from beeai_framework.tools.weather.openmeteo import OpenMeteoTool
+from beeai_framework.workflows.agent import AgentWorkflow, AgentWorkflowInput
+
+
+async def main() -> None:
+    llm = ChatModel.from_name("ollama:granite3.2:2b")
+    workflow = AgentWorkflow(name="Smart assistant")
+
+    workflow.add_agent(
+        name="Researcher",
+        role="A diligent researcher.",
+        instructions="You look up and provide information about a specific topic.",
+        tools=[WikipediaTool()],
+        llm=llm,
+    )
+
+    workflow.add_agent(
+        name="WeatherForecaster",
+        role="A weather reporter.",
+        instructions="You provide detailed weather reports.",
+        tools=[OpenMeteoTool()],
+        llm=llm,
+    )
+
+    workflow.add_agent(
+        name="DataSynthesizer",
+        role="A meticulous and creative data synthesizer",
+        instructions="You can combine disparate information into a final coherent summary.",
+        llm=llm,
+    )
+
+    location = "Saint-Tropez"
+
+    response = await workflow.run(
+        inputs=[
+            AgentWorkflowInput(
+                prompt=f"Provide a short history of {location}.",
+            ),
+            AgentWorkflowInput(
+                prompt=f"Provide a comprehensive weather summary for {location} today.",
+                expected_output="Essential weather details such as chance of rain, temperature and wind. Only report information that is available.",
+            ),
+            AgentWorkflowInput(
+                prompt=f"Summarize the historical and weather data for {location}.",
+                expected_output=f"A paragraph that describes the history of {location}, followed by the current weather conditions.",
+            ),
+        ]
+    ).on(
+        "success",
+        lambda data, event: print(
+            f"\n-> Step '{data.step}' has been completed with the following outcome.\n\n{data.state.final_answer}"
+        ),
+    )
+
+    print("==== Final Answer ====")
+    print(response.result.final_answer)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())