Skip to content

Commit 3e3e31c

Browse files
committed
feat: port rust model pull logic to use rust AST
Signed-off-by: Nick Mitchell <[email protected]>
1 parent e29c2bb commit 3e3e31c

File tree

4 files changed

+71
-61
lines changed

4 files changed

+71
-61
lines changed
Lines changed: 33 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
1-
use yaml_rust2::Yaml;
1+
use crate::pdl::ast::PdlBlock;
22

33
/// Extract models referenced by the programs
4-
pub fn extract_models(programs: Vec<Yaml>) -> Vec<String> {
5-
extract_values(programs, "model")
4+
pub fn extract_models(program: &PdlBlock) -> Vec<String> {
5+
extract_values(program, "model")
66
}
77

88
/// Take a list of Yaml fragments and produce a vector of the string-valued entries of the given field
9-
pub fn extract_values(programs: Vec<Yaml>, field: &str) -> Vec<String> {
10-
let mut values = programs
11-
.into_iter()
12-
.flat_map(|p| extract_one_values(p, field))
13-
.collect::<Vec<String>>();
9+
pub fn extract_values(program: &PdlBlock, field: &str) -> Vec<String> {
10+
let mut values = vec![];
11+
extract_values_iter(program, field, &mut values);
1412

1513
// A single program may specify the same model more than once. Dedup!
1614
values.sort();
@@ -20,38 +18,37 @@ pub fn extract_values(programs: Vec<Yaml>, field: &str) -> Vec<String> {
2018
}
2119

2220
/// Take one Yaml fragment and produce a vector of the string-valued entries of the given field
23-
fn extract_one_values(program: Yaml, field: &str) -> Vec<String> {
24-
let mut values: Vec<String> = Vec::new();
25-
21+
fn extract_values_iter(program: &PdlBlock, field: &str, values: &mut Vec<String>) {
2622
match program {
27-
Yaml::Hash(h) => {
28-
for (key, val) in h {
29-
match key {
30-
Yaml::String(f) if f == field => match &val {
31-
Yaml::String(m) => {
32-
values.push(m.to_string());
33-
}
34-
_ => {}
35-
},
36-
_ => {}
37-
}
38-
39-
for m in extract_one_values(val, field) {
40-
values.push(m)
41-
}
42-
}
23+
PdlBlock::Model(b) => values.push(b.model.clone()),
24+
PdlBlock::Repeat(b) => {
25+
extract_values_iter(&b.repeat, field, values);
4326
}
44-
45-
Yaml::Array(a) => {
46-
for val in a {
47-
for m in extract_one_values(val, field) {
48-
values.push(m)
49-
}
27+
PdlBlock::Message(b) => {
28+
extract_values_iter(&b.content, field, values);
29+
}
30+
PdlBlock::Array(b) => b
31+
.array
32+
.iter()
33+
.for_each(|p| extract_values_iter(p, field, values)),
34+
PdlBlock::Text(b) => b
35+
.text
36+
.iter()
37+
.for_each(|p| extract_values_iter(p, field, values)),
38+
PdlBlock::LastOf(b) => b
39+
.last_of
40+
.iter()
41+
.for_each(|p| extract_values_iter(p, field, values)),
42+
PdlBlock::If(b) => {
43+
extract_values_iter(&b.then, field, values);
44+
if let Some(else_) = &b.else_ {
45+
extract_values_iter(else_, field, values);
5046
}
5147
}
52-
48+
PdlBlock::Object(b) => b
49+
.object
50+
.values()
51+
.for_each(|p| extract_values_iter(p, field, values)),
5352
_ => {}
5453
}
55-
56-
values
5754
}

pdl-live-react/src-tauri/src/pdl/interpreter.rs

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -466,23 +466,27 @@ impl<'a> Interpreter<'a> {
466466
let scope = vm.new_scope_with_builtins();
467467

468468
// TODO vm.new_syntax_error(&err, Some(block.code.as_str()))
469-
let code_obj = match vm
470-
.compile(
471-
block.code.as_str(),
472-
vm::compiler::Mode::Exec,
473-
"<embedded>".to_owned(),
474-
) {
475-
Ok(x) => Ok(x),
476-
Err(exc) => Err(Box::<dyn Error + Send + Sync>::from(format!("Syntax error in Python code {:?}", exc))),
477-
}?;
469+
let code_obj = match vm.compile(
470+
block.code.as_str(),
471+
vm::compiler::Mode::Exec,
472+
"<embedded>".to_owned(),
473+
) {
474+
Ok(x) => Ok(x),
475+
Err(exc) => Err(Box::<dyn Error + Send + Sync>::from(format!(
476+
"Syntax error in Python code {:?}",
477+
exc
478+
))),
479+
}?;
478480

479481
// TODO vm.print_exception(exc);
480482
match vm.run_code_obj(code_obj, scope.clone()) {
481483
Ok(_) => Ok(()),
482484
Err(exc) => {
483485
vm.print_exception(exc);
484-
Err(Box::<dyn Error + Send + Sync>::from("Error executing Python code"))
485-
},
486+
Err(Box::<dyn Error + Send + Sync>::from(
487+
"Error executing Python code",
488+
))
489+
}
486490
}?;
487491

488492
match scope.globals.get_item("result", vm) {
@@ -491,8 +495,10 @@ impl<'a> Interpreter<'a> {
491495
Ok(x) => Ok(x),
492496
Err(exc) => {
493497
vm.print_exception(exc);
494-
Err(Box::<dyn Error + Send + Sync>::from("Unable to stringify Python 'result' value"))
495-
},
498+
Err(Box::<dyn Error + Send + Sync>::from(
499+
"Unable to stringify Python 'result' value",
500+
))
501+
}
496502
}?;
497503
let messages = vec![ChatMessage::user(result_string.as_str().to_string())];
498504
let trace = PdlBlock::PythonCode(block.clone());
@@ -927,7 +933,7 @@ pub fn run_sync(program: &PdlBlock, cwd: Option<PathBuf>, debug: bool) -> Interp
927933
}
928934

929935
/// Read in a file from disk and parse it as a PDL program
930-
fn parse_file(path: &PathBuf) -> Result<PdlBlock, PdlError> {
936+
pub fn parse_file(path: &PathBuf) -> Result<PdlBlock, PdlError> {
931937
from_reader(File::open(path)?)
932938
.map_err(|err| Box::<dyn Error + Send + Sync>::from(err.to_string()))
933939
}
@@ -937,6 +943,7 @@ pub async fn run_file(source_file_path: &str, debug: bool) -> Interpretation {
937943
let cwd = path.parent().and_then(|cwd| Some(cwd.to_path_buf()));
938944
let program = parse_file(&path)?;
939945

946+
crate::pdl::pull::pull_if_needed(&program).await?;
940947
run(&program, cwd, debug).await
941948
}
942949

pdl-live-react/src-tauri/src/pdl/pull.rs

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,24 @@
1-
use ::std::io::{Error, ErrorKind};
1+
use ::std::io::Error;
22

33
use duct::cmd;
44
use rayon::prelude::*;
5-
use yaml_rust2::{Yaml, YamlLoader};
65

6+
use crate::pdl::ast::PdlBlock;
77
use crate::pdl::extract;
8+
use crate::pdl::interpreter::parse_file;
89

9-
/// Read the given filesystem path and produce a potentially multi-document Yaml
10-
fn from_path(path: &str) -> Result<Vec<Yaml>, Error> {
11-
let content = std::fs::read_to_string(path)?;
12-
YamlLoader::load_from_str(&content).map_err(|e| Error::new(ErrorKind::Other, e.to_string()))
10+
pub async fn pull_if_needed_from_path(
11+
source_file_path: &str,
12+
) -> Result<(), Box<dyn ::std::error::Error + Send + Sync>> {
13+
let program = parse_file(&::std::path::PathBuf::from(source_file_path))?;
14+
pull_if_needed(&program)
15+
.await
16+
.map_err(|e| Box::from(e.to_string()))
1317
}
1418

1519
/// Pull models (in parallel) from the PDL program in the given filepath.
16-
pub async fn pull_if_needed(path: &str) -> Result<(), Error> {
17-
extract::extract_models(from_path(path)?)
20+
pub async fn pull_if_needed(program: &PdlBlock) -> Result<(), Error> {
21+
extract::extract_models(program)
1822
.into_par_iter()
1923
.try_for_each(|model| match model {
2024
m if model.starts_with("ollama/") => ollama_pull_if_needed(&m[7..]),

pdl-live-react/src-tauri/src/pdl/run.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use duct::cmd;
33
use futures::executor::block_on;
44

55
use crate::pdl::pip::pip_install_if_needed;
6-
use crate::pdl::pull::pull_if_needed;
6+
use crate::pdl::pull::pull_if_needed_from_path;
77
use crate::pdl::requirements::PDL_INTERPRETER;
88

99
#[cfg(desktop)]
@@ -19,11 +19,13 @@ pub fn run_pdl_program(
1919
);
2020

2121
// async the model pull and pip installs
22-
let pull_future = pull_if_needed(&source_file_path);
22+
let pull_future = pull_if_needed_from_path(&source_file_path);
2323
let bin_path_future = pip_install_if_needed(&PDL_INTERPRETER);
2424

2525
// wait for any model pulls to finish
26-
block_on(pull_future)?;
26+
if let Err(e) = block_on(pull_future) {
27+
return Err(e);
28+
}
2729

2830
// wait for any pip installs to finish
2931
let bin_path = block_on(bin_path_future)?;

0 commit comments

Comments
 (0)