Skip to content

Commit a1123b0

Browse files
committed
streaming ollama
Signed-off-by: Nick Mitchell <[email protected]>
1 parent 821bbc5 commit a1123b0

File tree

3 files changed

+66
-4
lines changed

3 files changed

+66
-4
lines changed

pdl-live-react/src-tauri/Cargo.lock

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pdl-live-react/src-tauri/Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,12 @@ base64ct = { version = "1.7.1", features = ["alloc"] }
 dirs = "6.0.0"
 serde_norway = "0.9.42"
 minijinja = { version = "2.9.0", features = ["custom_syntax"] }
-ollama-rs = { version = "0.3.0", features = ["tokio"] }
+ollama-rs = { version = "0.3.0", features = ["stream"] }
 owo-colors = "4.2.0"
 rustpython-vm = "0.4.0"
 async-recursion = "1.1.1"
+tokio-stream = "0.1.17"
+tokio = { version = "1.44.1", features = ["io-std"] }

 [target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies]
 tauri-plugin-cli = "2"

pdl-live-react/src-tauri/src/pdl/interpreter.rs

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
 // use ::std::cell::LazyCell;
 use ::std::collections::HashMap;
+use std::sync::{Arc, Mutex};
 // use ::std::env::current_dir;
 use ::std::error::Error;
 use ::std::fs::File;
@@ -8,10 +9,12 @@ use ::std::fs::File;
 use async_recursion::async_recursion;
 use minijinja::{syntax::SyntaxConfig, Environment};
 use owo_colors::OwoColorize;
+use tokio::io::{stdout, AsyncWriteExt};
+use tokio_stream::StreamExt;

 use ollama_rs::{
     generation::{
-        chat::{request::ChatMessageRequest, ChatMessage, MessageRole},
+        chat::{request::ChatMessageRequest, ChatMessage, ChatMessageResponse, MessageRole},
         tools::ToolInfo,
     },
     models::ModelOptions,
@@ -288,7 +291,7 @@ impl<'a> Interpreter<'a> {
             pdl_model
                 if pdl_model.starts_with("ollama/") || pdl_model.starts_with("ollama_chat/") =>
             {
-                let mut ollama = Ollama::default();
+                let ollama = Ollama::default();
                 let model = if pdl_model.starts_with("ollama/") {
                     &pdl_model[7..]
                 } else {
@@ -313,7 +316,7 @@ impl<'a> Interpreter<'a> {
                     Some(x) => x,
                     None => (&ChatMessage::user("".into()), &[]),
                 };
-                let mut history = Vec::from(history_slice);
+                let history = Vec::from(history_slice);
                 if self.debug {
                     eprintln!(
                         "Ollama {:?} model={:?} prompt={:?} history={:?}",
@@ -327,6 +330,7 @@ impl<'a> Interpreter<'a> {
                 let req = ChatMessageRequest::new(model.into(), vec![prompt.clone()])
                     .options(options)
                     .tools(tools);
+                /* if we ever want non-streaming:
                 let res = ollama
                     .send_chat_messages_with_history(
                         &mut history,
@@ -349,6 +353,48 @@ impl<'a> Interpreter<'a> {
                 }
                 // dbg!(history);
                 Ok((vec![res.message], PdlBlock::Model(trace)))
+                */
+                let mut stream = ollama
+                    .send_chat_messages_with_history_stream(
+                        Arc::new(Mutex::new(history)),
+                        req,
+                        //ollama.generate(GenerationRequest::new(model.into(), prompt),
+                    )
+                    .await?;
+                // dbg!("Model result {:?}", &res);
+
+                let mut last_res: Option<ChatMessageResponse> = None;
+                let mut response_string = String::new();
+                let mut stdout = stdout();
+                stdout.write_all(b"\x1b[1mAssistant: \x1b[0m").await?;
+                while let Some(Ok(res)) = stream.next().await {
+                    stdout.write_all(b"\x1b[32m").await?; // green
+                    stdout.write_all(res.message.content.as_bytes()).await?;
+                    stdout.flush().await?;
+                    stdout.write_all(b"\x1b[0m").await?; // reset color
+                    response_string += res.message.content.as_str();
+                    last_res = Some(res);
+                }
+
+                let mut trace = block.clone();
+                trace.pdl_result = Some(response_string.clone());
+
+                if let Some(res) = last_res {
+                    if let Some(usage) = res.final_data {
+                        trace.pdl_usage = Some(PdlUsage {
+                            prompt_tokens: usage.prompt_eval_count,
+                            prompt_nanos: usage.prompt_eval_duration,
+                            completion_tokens: usage.eval_count,
+                            completion_nanos: usage.eval_duration,
+                        });
+                    }
+                    let mut message = res.message.clone();
+                    message.content = response_string;
+                    Ok((vec![message], PdlBlock::Model(trace)))
+                } else {
+                    Ok((vec![], PdlBlock::Model(trace)))
+                }
+                // dbg!(history);
             }
             _ => Err(Box::from(format!("Unsupported model {}", block.model))),
         }

0 commit comments

Comments
 (0)