Skip to content

Commit ace9a85

Browse files
committed
avoid block_on on model calls
Signed-off-by: Nick Mitchell <[email protected]>
1 parent a49a59c commit ace9a85

File tree

7 files changed

+87
-47
lines changed

7 files changed

+87
-47
lines changed

pdl-live-react/src-tauri/Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pdl-live-react/src-tauri/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ minijinja = { version = "2.9.0", features = ["custom_syntax"] }
3838
ollama-rs = { version = "0.3.0", features = ["tokio"] }
3939
owo-colors = "4.2.0"
4040
rustpython-vm = "0.4.0"
41+
async-recursion = "1.1.1"
4142

4243
[target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies]
4344
tauri-plugin-cli = "2"

pdl-live-react/src-tauri/src/cli.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use urlencoding::encode;
55

66
use crate::compile;
77
use crate::gui::new_window;
8-
use crate::pdl::interpreter::run_file as runr;
8+
use crate::pdl::interpreter::run_file_sync as runr;
99
use crate::pdl::run::run_pdl_program;
1010

1111
#[cfg(desktop)]

pdl-live-react/src-tauri/src/commands/interpreter.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ use crate::pdl::interpreter::{pretty_print, run_string};
22

33
#[tauri::command]
44
pub async fn run_pdl_program(program: String, debug: bool) -> Result<String, String> {
5-
let (messages, _) = run_string(&program, debug).map_err(|err| err.to_string())?;
5+
let (messages, _) = run_string(&program, debug)
6+
.await
7+
.map_err(|err| err.to_string())?;
68

79
Ok(pretty_print(&messages))
810
}

pdl-live-react/src-tauri/src/pdl/interpreter.rs

Lines changed: 73 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use ::std::error::Error;
55
use ::std::fs::File;
66
// use ::std::path::PathBuf;
77

8+
use async_recursion::async_recursion;
89
use minijinja::{syntax::SyntaxConfig, Environment};
910
use owo_colors::OwoColorize;
1011

@@ -27,7 +28,8 @@ use crate::pdl::ast::{
2728

2829
type Context = Vec<ChatMessage>;
2930
type Scope = HashMap<String, Value>;
30-
type Interpretation = Result<(Context, PdlBlock), Box<dyn Error>>;
31+
type Interpretation = Result<(Context, PdlBlock), Box<dyn Error + Send + Sync>>;
32+
type InterpretationSync = Result<(Context, PdlBlock), Box<dyn Error>>;
3133

3234
struct Interpreter<'a> {
3335
// batch: u32,
@@ -65,7 +67,7 @@ impl<'a> Interpreter<'a> {
6567
}
6668
}
6769

68-
fn run_with_emit(
70+
async fn run_with_emit(
6971
&mut self,
7072
program: &PdlBlock,
7173
context: Context,
@@ -75,12 +77,12 @@ impl<'a> Interpreter<'a> {
7577
self.emit = emit;
7678

7779
let (messages, trace) = match program {
78-
PdlBlock::String(s) => self.run_string(s, context),
79-
PdlBlock::Call(block) => self.run_call(block, context),
80-
PdlBlock::PythonCode(block) => self.run_python_code(block, context),
81-
PdlBlock::Model(block) => self.run_model(block, context),
82-
PdlBlock::Repeat(block) => self.run_repeat(block, context),
83-
PdlBlock::Text(block) => self.run_text(block, context),
80+
PdlBlock::String(s) => self.run_string(s, context).await,
81+
PdlBlock::Call(block) => self.run_call(block, context).await,
82+
PdlBlock::PythonCode(block) => self.run_python_code(block, context).await,
83+
PdlBlock::Model(block) => self.run_model(block, context).await,
84+
PdlBlock::Repeat(block) => self.run_repeat(block, context).await,
85+
PdlBlock::Text(block) => self.run_text(block, context).await,
8486
_ => Err(Box::from(format!("Unsupported block {:?}", program))),
8587
}?;
8688

@@ -96,19 +98,21 @@ impl<'a> Interpreter<'a> {
9698
Ok((messages, trace))
9799
}
98100

99-
fn run_quiet(&mut self, program: &PdlBlock, context: Context) -> Interpretation {
100-
self.run_with_emit(program, context, false)
101+
#[async_recursion]
102+
async fn run_quiet(&mut self, program: &PdlBlock, context: Context) -> Interpretation {
103+
self.run_with_emit(program, context, false).await
101104
}
102105

103-
fn run(&mut self, program: &PdlBlock, context: Context) -> Interpretation {
104-
self.run_with_emit(program, context, self.emit)
106+
#[async_recursion]
107+
async fn run(&mut self, program: &PdlBlock, context: Context) -> Interpretation {
108+
self.run_with_emit(program, context, self.emit).await
105109
}
106110

107111
// Evaluate as a Jinja2 expression
108112
fn eval<T: serde::de::DeserializeOwned + ::std::convert::From<String>>(
109113
&self,
110114
expr: &String,
111-
) -> Result<T, Box<dyn Error>> {
115+
) -> Result<T, Box<dyn Error + Send + Sync>> {
112116
let result = self
113117
.jinja_env
114118
.render_str(expr.as_str(), self.scope.last().unwrap_or(&HashMap::new()))?;
@@ -127,7 +131,7 @@ impl<'a> Interpreter<'a> {
127131
}
128132

129133
// Run a PdlBlock::String
130-
fn run_string(&self, msg: &String, _context: Context) -> Interpretation {
134+
async fn run_string(&self, msg: &String, _context: Context) -> Interpretation {
131135
let trace = self.eval::<PdlBlock>(msg)?;
132136
if self.debug {
133137
eprintln!("String {} -> {:?}", msg, trace);
@@ -142,7 +146,7 @@ impl<'a> Interpreter<'a> {
142146
}
143147

144148
// Run a PdlBlock::Call
145-
fn run_call(&mut self, block: &PdlCallBlock, context: Context) -> Interpretation {
149+
async fn run_call(&mut self, block: &PdlCallBlock, context: Context) -> Interpretation {
146150
if self.debug {
147151
eprintln!("Call {:?}({:?})", block.call, block.args);
148152
}
@@ -154,7 +158,7 @@ impl<'a> Interpreter<'a> {
154158
// args was a string that eval'd to an Object
155159
Value::Object(m) => Ok(Some(self.to_pdl(&m))),
156160
// args was a string that eval'd to something we don't understand
157-
y => Err(Box::<dyn Error>::from(format!(
161+
y => Err(Box::<dyn Error + Send + Sync>::from(format!(
158162
"Invalid arguments to call {:?}",
159163
y
160164
))),
@@ -170,7 +174,7 @@ impl<'a> Interpreter<'a> {
170174
self.extend_scope_with_map(&args);
171175

172176
let res = match self.eval::<PdlBlock>(&block.call)? {
173-
PdlBlock::Function(f) => self.run(&f.return_, context.clone()),
177+
PdlBlock::Function(f) => self.run(&f.return_, context.clone()).await,
174178
_ => Err(Box::from(format!("call of non-function {:?}", &block.call))),
175179
};
176180
self.scope.pop();
@@ -230,7 +234,11 @@ impl<'a> Interpreter<'a> {
230234
}
231235

232236
// Run a PdlBlock::PythonCode
233-
fn run_python_code(&mut self, block: &PdlPythonCodeBlock, context: Context) -> Interpretation {
237+
async fn run_python_code(
238+
&mut self,
239+
block: &PdlPythonCodeBlock,
240+
context: Context,
241+
) -> Interpretation {
234242
use rustpython_vm as vm;
235243
vm::Interpreter::without_stdlib(Default::default()).enter(|vm| -> Interpretation {
236244
let scope = vm.new_scope_with_builtins();
@@ -275,7 +283,7 @@ impl<'a> Interpreter<'a> {
275283
}
276284

277285
// Run a PdlBlock::Model
278-
fn run_model(&mut self, block: &PdlModelBlock, context: Context) -> Interpretation {
286+
async fn run_model(&mut self, block: &PdlModelBlock, context: Context) -> Interpretation {
279287
match &block.model {
280288
pdl_model
281289
if pdl_model.starts_with("ollama/") || pdl_model.starts_with("ollama_chat/") =>
@@ -295,7 +303,7 @@ impl<'a> Interpreter<'a> {
295303
let messages = match &block.input {
296304
Some(input) => {
297305
// TODO ignoring trace
298-
let (messages, _trace) = self.run_quiet(&*input, context)?;
306+
let (messages, _trace) = self.run_quiet(&*input, context).await?;
299307
messages
300308
}
301309
None => context,
@@ -310,24 +318,27 @@ impl<'a> Interpreter<'a> {
310318
eprintln!(
311319
"Ollama {:?} model={:?} prompt={:?} history={:?}",
312320
block.description.clone().unwrap_or("".into()),
313-
&block.model,
314-
&prompt,
315-
&history
321+
block.model,
322+
prompt,
323+
history
316324
);
317325
}
318326

319327
let req = ChatMessageRequest::new(model.into(), vec![prompt.clone()])
320328
.options(options)
321329
.tools(tools);
322-
let res = /*self.rt.*/tauri::async_runtime::block_on(ollama.send_chat_messages_with_history(
323-
&mut history,
324-
req,
325-
//ollama.generate(GenerationRequest::new(model.into(), prompt),
326-
))?;
330+
let res = ollama
331+
.send_chat_messages_with_history(
332+
&mut history,
333+
req,
334+
//ollama.generate(GenerationRequest::new(model.into(), prompt),
335+
)
336+
.await?;
327337
// dbg!("Model result {:?}", &res);
328338

329339
let mut trace = block.clone();
330340
trace.pdl_result = Some(res.message.content.clone());
341+
331342
if let Some(usage) = res.final_data {
332343
trace.pdl_usage = Some(PdlUsage {
333344
prompt_tokens: usage.prompt_eval_count,
@@ -344,7 +355,7 @@ impl<'a> Interpreter<'a> {
344355
}
345356

346357
// Run a PdlBlock::Repeat
347-
fn run_repeat(&mut self, block: &PdlRepeatBlock, _context: Context) -> Interpretation {
358+
async fn run_repeat(&mut self, block: &PdlRepeatBlock, _context: Context) -> Interpretation {
348359
let for_ = block
349360
.for_
350361
.iter()
@@ -368,7 +379,11 @@ impl<'a> Interpreter<'a> {
368379
}
369380
}
370381

371-
fn parse_result(&self, parser: &PdlParser, result: &String) -> Result<Value, Box<dyn Error>> {
382+
fn parse_result(
383+
&self,
384+
parser: &PdlParser,
385+
result: &String,
386+
) -> Result<Value, Box<dyn Error + Send + Sync>> {
372387
match parser {
373388
PdlParser::Json => Ok(from_str(result)?),
374389
}
@@ -395,7 +410,7 @@ impl<'a> Interpreter<'a> {
395410
}
396411

397412
// Run a PdlBlock::Text
398-
fn run_text(&mut self, block: &PdlTextBlock, context: Context) -> Interpretation {
413+
async fn run_text(&mut self, block: &PdlTextBlock, context: Context) -> Interpretation {
399414
if self.debug {
400415
eprintln!(
401416
"Text {:?}",
@@ -411,14 +426,14 @@ impl<'a> Interpreter<'a> {
411426
let mut output_blocks = vec![];
412427

413428
self.extend_scope_with_map(&block.defs);
414-
block.text.iter().try_for_each(|block| {
429+
let mut iter = block.text.iter();
430+
while let Some(block) = iter.next() {
415431
// run each element of the Text block
416-
let (this_messages, trace) = self.run(&block, input_messages.clone())?;
432+
let (this_messages, trace) = self.run(&block, input_messages.clone()).await?;
417433
input_messages.extend(this_messages.clone());
418434
output_messages.extend(this_messages);
419435
output_blocks.push(trace);
420-
Ok::<(), Box<dyn Error>>(())
421-
})?;
436+
}
422437
self.scope.pop();
423438

424439
let mut trace = block.clone();
@@ -455,22 +470,37 @@ impl<'a> Interpreter<'a> {
455470
}
456471
}
457472

458-
pub fn run(program: &PdlBlock, debug: bool) -> Interpretation {
473+
pub async fn run(program: &PdlBlock, debug: bool) -> Interpretation {
459474
let mut interpreter = Interpreter::new();
460475
interpreter.debug = debug;
461-
interpreter.run(&program, vec![])
476+
interpreter.run(&program, vec![]).await
477+
}
478+
479+
pub fn run_sync(program: &PdlBlock, debug: bool) -> InterpretationSync {
480+
tauri::async_runtime::block_on(run(program, debug))
481+
.map_err(|err| Box::<dyn ::std::error::Error>::from(err.to_string()))
482+
}
483+
484+
pub async fn run_file(source_file_path: &str, debug: bool) -> Interpretation {
485+
run(&from_reader(File::open(source_file_path)?)?, debug).await
486+
}
487+
488+
pub fn run_file_sync(source_file_path: &str, debug: bool) -> InterpretationSync {
489+
tauri::async_runtime::block_on(run_file(source_file_path, debug))
490+
.map_err(|err| Box::<dyn ::std::error::Error>::from(err.to_string()))
462491
}
463492

464-
pub fn run_file(source_file_path: &str, debug: bool) -> Interpretation {
465-
run(&from_reader(File::open(source_file_path)?)?, debug)
493+
pub async fn run_string(source: &str, debug: bool) -> Interpretation {
494+
run(&from_yaml_str(source)?, debug).await
466495
}
467496

468-
pub fn run_string(source: &str, debug: bool) -> Interpretation {
469-
run(&from_yaml_str(source)?, debug)
497+
pub async fn run_json(source: Value, debug: bool) -> Interpretation {
498+
run_string(&to_string(&source)?, debug).await
470499
}
471500

472-
pub fn run_json(source: Value, debug: bool) -> Interpretation {
473-
run_string(&to_string(&source)?, debug)
501+
pub fn run_json_sync(source: Value, debug: bool) -> InterpretationSync {
502+
tauri::async_runtime::block_on(run_json(source, debug))
503+
.map_err(|err| Box::<dyn ::std::error::Error>::from(err.to_string()))
474504
}
475505

476506
pub fn pretty_print(messages: &Vec<ChatMessage>) -> String {

pdl-live-react/src-tauri/src/pdl/interpreter_tests.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ mod tests {
66

77
use crate::pdl::{
88
ast::{PdlBlock, PdlModelBlock, PdlParser, PdlTextBlock},
9-
interpreter::{run, run_json},
9+
interpreter::{run_json_sync as run_json, run_sync as run},
1010
};
1111

1212
use ollama_rs::generation::chat::MessageRole;

pdl-live-react/src/page/Run.tsx

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ const initialInput = `text:
2828
def: foo
2929
- \${ foo.key }`
3030
export default function Run() {
31+
const [running, setRunning] = useState(false)
3132
const [input, setInput] = useState(initialInput)
3233
const [_error, setError] = useState(false)
3334

@@ -69,6 +70,7 @@ export default function Run() {
6970

7071
const run = async () => {
7172
try {
73+
setRunning(true)
7274
term?.reset()
7375
const result = await invoke("run_pdl_program", {
7476
program: input,
@@ -79,6 +81,8 @@ export default function Run() {
7981
} catch (err) {
8082
term?.write(String(err))
8183
setError(true)
84+
} finally {
85+
setRunning(false)
8286
}
8387
}
8488

@@ -132,7 +136,9 @@ export default function Run() {
132136
<Toolbar>
133137
<ToolbarContent>
134138
<ToolbarItem>
135-
<Button onClick={run}>Run</Button>
139+
<Button onClick={run} isLoading={running}>
140+
Run
141+
</Button>
136142
</ToolbarItem>
137143
</ToolbarContent>
138144
</Toolbar>

0 commit comments

Comments
 (0)