Skip to content

Commit c596d52

Browse files
committed
feat: rust interpreter support for modelResponse, ollama-rs tool calling, and no-stream
Signed-off-by: Nick Mitchell <[email protected]>
1 parent 618b838 commit c596d52

File tree

5 files changed

+138
-70
lines changed

5 files changed

+138
-70
lines changed

pdl-live-react/src-tauri/Cargo.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pdl-live-react/src-tauri/Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,16 @@ base64ct = { version = "1.7.1", features = ["alloc"] }
3535
dirs = "6.0.0"
3636
serde_norway = "0.9.42"
3737
minijinja = { version = "2.9.0", features = ["custom_syntax"] }
38-
ollama-rs = { version = "0.3.0", features = ["stream"] }
38+
#ollama-rs = { version = "0.3.0", features = ["stream"] }
39+
ollama-rs = { git = "https://github.com/starpit/ollama-rs.git", branch = "tools-pub-7", features = ["stream"] }
3940
owo-colors = "4.2.0"
4041
rustpython-vm = "0.4.0"
4142
async-recursion = "1.1.1"
4243
tokio-stream = "0.1.17"
4344
tokio = { version = "1.44.1", features = ["io-std"] }
4445
indexmap = { version = "2.9.0", features = ["serde"] }
4546
rustpython-stdlib = { version = "0.4.0", features = ["zlib"] }
47+
schemars = "0.8.22"
4648

4749
[target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies]
4850
tauri-plugin-cli = "2"

pdl-live-react/src-tauri/src/cli.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ pub fn setup(app: &mut tauri::App) -> Result<bool, Box<dyn ::std::error::Error>>
6060
.and_then(|a| a.value.as_bool())
6161
.or(Some(false))
6262
== Some(true),
63+
subcommand_args
64+
.get("no-stream")
65+
.and_then(|a| a.value.as_bool())
66+
.or(Some(false))
67+
== Some(false),
6368
)
6469
.and_then(|_trace| Ok(true)),
6570
"run" => run_pdl_program(

pdl-live-react/src-tauri/src/pdl/interpreter.rs

Lines changed: 125 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@ use tokio_stream::StreamExt;
1515
use ollama_rs::{
1616
generation::{
1717
chat::{request::ChatMessageRequest, ChatMessage, ChatMessageResponse, MessageRole},
18-
tools::ToolInfo,
18+
tools::{ToolFunctionInfo, ToolInfo, ToolType},
1919
},
2020
models::ModelOptions,
2121
Ollama,
2222
};
2323

24-
use serde_json::{from_str, to_string, Value};
24+
use serde_json::{from_str, json, to_string, Value};
2525
use serde_norway::{from_reader, from_str as from_yaml_str};
2626

2727
use crate::pdl::ast::{
@@ -45,6 +45,7 @@ struct Interpreter<'a> {
4545
scope: Vec<Scope>,
4646
debug: bool,
4747
emit: bool,
48+
stream: bool,
4849
}
4950

5051
impl<'a> Interpreter<'a> {
@@ -67,6 +68,7 @@ impl<'a> Interpreter<'a> {
6768
scope: vec![Scope::new()],
6869
debug: false,
6970
emit: true,
71+
stream: true,
7072
}
7173
}
7274

@@ -76,14 +78,13 @@ impl<'a> Interpreter<'a> {
7678
context: Context,
7779
emit: bool,
7880
) -> Interpretation {
79-
if self.debug {
81+
/* if self.debug {
8082
if let Some(scope) = self.scope.last() {
8183
if scope.len() > 0 {
8284
eprintln!("Run with Scope {:?}", scope);
8385
}
8486
}
85-
}
86-
87+
} */
8788
let prior_emit = self.emit;
8889
self.emit = emit;
8990

@@ -118,7 +119,9 @@ impl<'a> Interpreter<'a> {
118119
}?;
119120

120121
if match program {
121-
PdlBlock::Call(_) | PdlBlock::Model(_) => false,
122+
PdlBlock::Text(_) | PdlBlock::LastOf(_) | PdlBlock::Call(_) | PdlBlock::Model(_) => {
123+
false
124+
}
122125
_ => self.emit,
123126
} {
124127
println!("{}", pretty_print(&messages));
@@ -150,7 +153,7 @@ impl<'a> Interpreter<'a> {
150153
let backup = result.clone();
151154
Ok(from_str(&result).unwrap_or_else(|err| {
152155
if self.debug {
153-
eprintln!("Treating as plain string {}", &result);
156+
eprintln!("Treating as plain string {}", result);
154157
eprintln!("... due to {}", err);
155158
}
156159
backup.into()
@@ -332,7 +335,10 @@ impl<'a> Interpreter<'a> {
332335

333336
self.run(&c.function.return_, context.clone()).await
334337
}
335-
_ => Err(Box::from(format!("call of non-function {:?}", &block.call))),
338+
x => Err(Box::from(format!(
339+
"call of non-function {:?}->{:?}",
340+
block.call, x
341+
))),
336342
};
337343

338344
if let Some(_) = block.args {
@@ -438,10 +444,36 @@ impl<'a> Interpreter<'a> {
438444
0.0
439445
};
440446

441-
let tools = if let Some(Value::Array(_tools)) = parameters.get(&"tools".to_string()) {
442-
// TODO
443-
//tools.into_iter().map(|tool| function!()).collect()
444-
vec![]
447+
let tools = if let Some(Value::Array(tools)) = parameters.get("tools") {
448+
tools
449+
.into_iter()
450+
.filter_map(|tool| tool.get("function"))
451+
.filter_map(|tool| {
452+
//from_str(&to_string(tool)?)
453+
match (
454+
tool.get("name"),
455+
tool.get("description"),
456+
tool.get("parameters"),
457+
) {
458+
(
459+
Some(Value::String(name)),
460+
Some(Value::String(description)),
461+
Some(Value::Object(parameters)),
462+
) => Some(ToolInfo {
463+
tool_type: ToolType::Function,
464+
function: ToolFunctionInfo {
465+
name: name.to_string(),
466+
description: description.to_string(),
467+
parameters: schemars::schema_for_value!(parameters),
468+
},
469+
}),
470+
_ => {
471+
eprintln!("Error: tools do not satisfy schema {:?}", tool);
472+
None
473+
}
474+
}
475+
})
476+
.collect()
445477
} else {
446478
vec![]
447479
};
@@ -517,7 +549,7 @@ impl<'a> Interpreter<'a> {
517549
pdl_model
518550
if pdl_model.starts_with("ollama/") || pdl_model.starts_with("ollama_chat/") =>
519551
{
520-
let ollama = Ollama::default();
552+
let mut ollama = Ollama::default();
521553
let model = if pdl_model.starts_with("ollama/") {
522554
&pdl_model[7..]
523555
} else {
@@ -526,7 +558,8 @@ impl<'a> Interpreter<'a> {
526558

527559
let (options, tools) = self.to_ollama_model_options(&block.parameters);
528560
if self.debug {
529-
println!("Model options {:?}", options);
561+
eprintln!("Model options {:?} {:?}", block.description, options);
562+
eprintln!("Model tools {:?} {:?}", block.description, tools);
530563
}
531564

532565
let input_messages = match &block.input {
@@ -542,7 +575,7 @@ impl<'a> Interpreter<'a> {
542575
Some(x) => x,
543576
None => (&ChatMessage::user("".into()), &[]),
544577
};
545-
let history = Vec::from(history_slice);
578+
let mut history = Vec::from(history_slice);
546579
if self.debug {
547580
eprintln!(
548581
"Ollama {:?} model={:?} prompt={:?} history={:?}",
@@ -560,52 +593,62 @@ impl<'a> Interpreter<'a> {
560593
let req = ChatMessageRequest::new(model.into(), vec![prompt.clone()])
561594
.options(options)
562595
.tools(tools);
563-
/* if we ever want non-streaming:
564-
let res = ollama
565-
.send_chat_messages_with_history(
566-
&mut history,
567-
req,
568-
//ollama.generate(GenerationRequest::new(model.into(), prompt),
569-
)
570-
.await?;
571-
// dbg!("Model result {:?}", &res);
572596

573-
let mut trace = block.clone();
574-
trace.pdl_result = Some(res.message.content.clone());
575-
576-
if let Some(usage) = res.final_data {
577-
trace.pdl_usage = Some(PdlUsage {
578-
prompt_tokens: usage.prompt_eval_count,
579-
prompt_nanos: usage.prompt_eval_duration,
580-
completion_tokens: usage.eval_count,
581-
completion_nanos: usage.eval_duration,
582-
});
583-
}
584-
// dbg!(history);
585-
Ok((vec![res.message], PdlBlock::Model(trace)))
586-
*/
587-
let mut stream = ollama
588-
.send_chat_messages_with_history_stream(
589-
Arc::new(Mutex::new(history)),
590-
req,
591-
//ollama.generate(GenerationRequest::new(model.into(), prompt),
592-
)
593-
.await?;
594-
// dbg!("Model result {:?}", &res);
595-
596-
let mut last_res: Option<ChatMessageResponse> = None;
597-
let mut response_string = String::new();
598-
let mut stdout = stdout();
599-
stdout.write_all(b"\x1b[1mAssistant: \x1b[0m").await?;
600-
while let Some(Ok(res)) = stream.next().await {
601-
stdout.write_all(b"\x1b[32m").await?; // green
602-
stdout.write_all(res.message.content.as_bytes()).await?;
603-
stdout.flush().await?;
604-
stdout.write_all(b"\x1b[0m").await?; // reset color
605-
response_string += res.message.content.as_str();
606-
last_res = Some(res);
597+
let (last_res, response_string) = if !self.stream {
598+
let res = ollama
599+
.send_chat_messages_with_history(&mut history, req)
600+
.await?;
601+
let response_string = res.message.content.clone();
602+
print!("{}", response_string);
603+
(Some(res), response_string)
604+
} else {
605+
let mut stream = ollama
606+
.send_chat_messages_with_history_stream(
607+
Arc::new(Mutex::new(history)),
608+
req,
609+
//ollama.generate(GenerationRequest::new(model.into(), prompt),
610+
)
611+
.await?;
612+
// dbg!("Model result {:?}", &res);
613+
614+
let emit = if let Some(_) = &block.model_response {
615+
false
616+
} else {
617+
true
618+
};
619+
620+
let mut last_res: Option<ChatMessageResponse> = None;
621+
let mut response_string = String::new();
622+
let mut stdout = stdout();
623+
if emit {
624+
stdout.write_all(b"\x1b[1mAssistant: \x1b[0m").await?;
625+
}
626+
while let Some(Ok(res)) = stream.next().await {
627+
if emit {
628+
stdout.write_all(b"\x1b[32m").await?; // green
629+
stdout.write_all(res.message.content.as_bytes()).await?;
630+
stdout.flush().await?;
631+
stdout.write_all(b"\x1b[0m").await?; // reset color
632+
}
633+
response_string += res.message.content.as_str();
634+
last_res = Some(res);
635+
}
636+
if emit {
637+
stdout.write_all(b"\n").await?;
638+
}
639+
640+
(last_res, response_string)
641+
};
642+
643+
if let Some(_) = &block.model_response {
644+
if let Some(ref res) = last_res {
645+
self.def(
646+
&block.model_response,
647+
&self.resultify_as_litellm(&from_str(&to_string(&res)?)?),
648+
&None,
649+
)?;
650+
}
607651
}
608-
stdout.write_all(b"\n").await?;
609652

610653
let mut trace = block.clone();
611654
trace.pdl_result = Some(response_string.clone());
@@ -653,6 +696,15 @@ impl<'a> Interpreter<'a> {
653696
}
654697
}
655698

699+
/// Transform a JSON Value into a PdlResult object that is compatible with litellm's model response schema
700+
fn resultify_as_litellm(&self, value: &Value) -> PdlResult {
701+
self.resultify(&json!({
702+
"choices": [
703+
value
704+
]
705+
}))
706+
}
707+
656708
/// Run a PdlBlock::Data
657709
async fn run_data(&mut self, block: &DataBlock, _context: Context) -> Interpretation {
658710
if self.debug {
@@ -821,7 +873,7 @@ impl<'a> Interpreter<'a> {
821873
while let Some(block) = iter.next() {
822874
// run each element of the Text block
823875
let (this_result, this_messages, trace) =
824-
self.run_quiet(&block, input_messages.clone()).await?;
876+
self.run(&block, input_messages.clone()).await?;
825877
input_messages.extend(this_messages.clone());
826878
output_results.push(this_result);
827879

@@ -918,17 +970,23 @@ impl<'a> Interpreter<'a> {
918970
}
919971
}
920972

921-
pub async fn run(program: &PdlBlock, cwd: Option<PathBuf>, debug: bool) -> Interpretation {
973+
pub async fn run(
974+
program: &PdlBlock,
975+
cwd: Option<PathBuf>,
976+
debug: bool,
977+
stream: bool,
978+
) -> Interpretation {
922979
let mut interpreter = Interpreter::new();
923980
interpreter.debug = debug;
981+
interpreter.stream = stream;
924982
if let Some(cwd) = cwd {
925983
interpreter.cwd = cwd
926984
};
927985
interpreter.run(&program, vec![]).await
928986
}
929987

930988
pub fn run_sync(program: &PdlBlock, cwd: Option<PathBuf>, debug: bool) -> InterpretationSync {
931-
tauri::async_runtime::block_on(run(program, cwd, debug))
989+
tauri::async_runtime::block_on(run(program, cwd, debug, true))
932990
.map_err(|err| Box::<dyn Error>::from(err.to_string()))
933991
}
934992

@@ -938,22 +996,22 @@ pub fn parse_file(path: &PathBuf) -> Result<PdlBlock, PdlError> {
938996
.map_err(|err| Box::<dyn Error + Send + Sync>::from(err.to_string()))
939997
}
940998

941-
pub async fn run_file(source_file_path: &str, debug: bool) -> Interpretation {
999+
pub async fn run_file(source_file_path: &str, debug: bool, stream: bool) -> Interpretation {
9421000
let path = PathBuf::from(source_file_path);
9431001
let cwd = path.parent().and_then(|cwd| Some(cwd.to_path_buf()));
9441002
let program = parse_file(&path)?;
9451003

9461004
crate::pdl::pull::pull_if_needed(&program).await?;
947-
run(&program, cwd, debug).await
1005+
run(&program, cwd, debug, stream).await
9481006
}
9491007

950-
pub fn run_file_sync(source_file_path: &str, debug: bool) -> InterpretationSync {
951-
tauri::async_runtime::block_on(run_file(source_file_path, debug))
1008+
pub fn run_file_sync(source_file_path: &str, debug: bool, stream: bool) -> InterpretationSync {
1009+
tauri::async_runtime::block_on(run_file(source_file_path, debug, stream))
9521010
.map_err(|err| Box::<dyn Error>::from(err.to_string()))
9531011
}
9541012

9551013
pub async fn run_string(source: &str, debug: bool) -> Interpretation {
956-
run(&from_yaml_str(source)?, None, debug).await
1014+
run(&from_yaml_str(source)?, None, debug, true).await
9571015
}
9581016

9591017
pub async fn run_json(source: Value, debug: bool) -> Interpretation {

pdl-live-react/src-tauri/tauri.conf.json

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@
5555
"required": true,
5656
"takesValue": true
5757
},
58+
{
59+
"name": "no-stream"
60+
},
5861
{
5962
"name": "debug",
6063
"short": "g"

0 commit comments

Comments (0)