Commit 316a31b

fix: in rust ast, allow ModelBlock model to be an expr

Signed-off-by: Nick Mitchell <[email protected]>

1 parent: a5400eb
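Concretely, this commit lets the `model` field of a PDL model block be either a literal string or an expression object. A minimal sketch of the two accepted shapes, inferred from the diffs and the new test below (the `pdl__expr` key comes from that test; the model name is a placeholder):

use serde_json::json; // serde_json assumed available, as in the test suite

fn main() {
    // Literal form: the only shape ModelBlock accepted before this commit.
    let literal = json!({ "model": "ollama/some-model" });

    // Expression form: enabled by this commit; exercised by the new
    // single_model_via_text_chain_expr test.
    let expr = json!({ "model": { "pdl__expr": "ollama/some-model" } });

    println!("{literal}\n{expr}");
}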

File tree

4 files changed: +207 −124 lines

pdl-live-react/src-tauri/src/pdl/ast.rs

Lines changed: 18 additions & 1 deletion

@@ -274,7 +274,7 @@ pub struct PdlUsage
 #[serde(tag = "kind", rename = "model")]
 #[builder(setter(into, strip_option), default)]
 pub struct ModelBlock {
-    pub model: String,
+    pub model: EvalsTo<String, String>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub parameters: Option<HashMap<String, Value>>,
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -468,6 +468,23 @@ pub enum EvalsTo<S, T> {
     Expr(Expr<S, T>),
 }

+impl Default for EvalsTo<String, String> {
+    fn default() -> Self {
+        EvalsTo::Const("".to_string())
+    }
+}
+
+impl From<&str> for EvalsTo<String, String> {
+    fn from(s: &str) -> Self {
+        EvalsTo::Const(s.to_string())
+    }
+}
+impl From<String> for EvalsTo<String, String> {
+    fn from(s: String) -> Self {
+        EvalsTo::Const(s)
+    }
+}
+
 /// Conditional control structure.
 ///
 /// Example:
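The two `From` impls keep existing construction sites compiling: with `#[builder(setter(into, strip_option), default)]` on `ModelBlock`, a plain `&str` or `String` still converts into the widened field type. A self-contained sketch of that mechanism, on simplified mirrors of the real types (the model name is a placeholder):

// Simplified mirror of the ast.rs types, just to show the conversion path.
#[derive(Debug)]
#[allow(dead_code)]
enum EvalsTo<S, T> {
    Const(T),
    Jinja(S),
}

impl From<&str> for EvalsTo<String, String> {
    fn from(s: &str) -> Self {
        EvalsTo::Const(s.to_string())
    }
}

#[derive(Debug)]
struct ModelBlock {
    model: EvalsTo<String, String>,
}

fn main() {
    // A builder generated with `setter(into)` calls `.into()` exactly like
    // this, so call sites that pass a bare string keep working unchanged.
    let block = ModelBlock {
        model: "ollama/some-model".into(),
    };
    println!("{:?}", block); // ModelBlock { model: Const("ollama/some-model") }
}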

pdl-live-react/src-tauri/src/pdl/extract.rs

Lines changed: 27 additions & 2 deletions

@@ -1,4 +1,6 @@
-use crate::pdl::ast::{Block, Body::*, Metadata, PdlBlock, PdlBlock::Advanced};
+use crate::pdl::ast::{
+    Block, Body::*, EvalsTo, Expr, Metadata, ModelBlock, PdlBlock, PdlBlock::Advanced,
+};

 /// Extract models referenced by the programs
 pub fn extract_models(program: &PdlBlock) -> Vec<String> {
@@ -28,7 +30,30 @@ fn extract_values_iter(program: &PdlBlock, field: &str, values: &mut Vec<String>
         PdlBlock::Function(b) => {
             extract_values_iter(&b.return_, field, values);
         }
-        Advanced(Block { body: Model(b), .. }) => values.push(b.model.clone()),
+        Advanced(Block {
+            body:
+                Model(ModelBlock {
+                    model: EvalsTo::<String, String>::Const(m),
+                    ..
+                }),
+            ..
+        }) => values.push(m.clone()),
+        Advanced(Block {
+            body:
+                Model(ModelBlock {
+                    model: EvalsTo::<String, String>::Jinja(m),
+                    ..
+                }),
+            ..
+        }) => values.push(m.clone()),
+        Advanced(Block {
+            body:
+                Model(ModelBlock {
+                    model: EvalsTo::<String, String>::Expr(Expr { pdl_expr: m, .. }),
+                    ..
+                }),
+            ..
+        }) => values.push(m.clone()),
         Advanced(Block {
             body: Repeat(b), ..
         }) => {
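The three new arms differ only in which `EvalsTo` variant carries the model string. A hypothetical consolidation, shown on simplified mirror types (not what the commit does; it keeps the arms explicit):

// Simplified mirrors of the ast.rs types; pdl_expr is the only Expr field
// that extraction reads.
#[allow(dead_code)]
enum EvalsTo {
    Const(String),
    Jinja(String),
    Expr(Expr),
}

struct Expr {
    pdl_expr: String,
}

// One match over the variant replaces the three pattern-matching arms.
fn push_model(model: &EvalsTo, values: &mut Vec<String>) {
    match model {
        EvalsTo::Const(m) | EvalsTo::Jinja(m) => values.push(m.clone()),
        EvalsTo::Expr(Expr { pdl_expr: m, .. }) => values.push(m.clone()),
    }
}

fn main() {
    let mut values = Vec::new();
    push_model(&EvalsTo::Jinja("ollama/some-model".into()), &mut values);
    push_model(
        &EvalsTo::Expr(Expr {
            pdl_expr: "ollama/another-model".into(),
        }),
        &mut values,
    );
    assert_eq!(values, ["ollama/some-model", "ollama/another-model"]);
}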

pdl-live-react/src-tauri/src/pdl/interpreter.rs

Lines changed: 142 additions & 121 deletions

@@ -261,6 +261,18 @@ impl<'a> Interpreter<'a> {
         }
     }

+    // TODO how can we better cope with the expected String return?
+    fn eval_string_to_string(
+        &self,
+        expr: &EvalsTo<String, String>,
+        state: &State,
+    ) -> Result<PdlResult, PdlError> {
+        match expr {
+            EvalsTo::Const(s) | EvalsTo::Jinja(s) => self.eval(s, state),
+            EvalsTo::Expr(e) => self.eval(&e.pdl_expr, state),
+        }
+    }
+
     /// Evaluate an Expr to a bool
     fn eval_to_bool(
         &self,
@@ -722,140 +734,149 @@ impl<'a> Interpreter<'a> {
         })
     }

-    /// Run a PdlBlock::Model
-    async fn run_model(
+    async fn run_ollama_model(
         &mut self,
+        pdl_model: String,
         block: &ModelBlock,
         metadata: &Metadata,
         state: &mut State,
     ) -> BodyInterpretation {
-        match &block.model {
-            pdl_model
-                if pdl_model.starts_with("ollama/") || pdl_model.starts_with("ollama_chat/") =>
-            {
-                let mut ollama = Ollama::default();
-                let model = if pdl_model.starts_with("ollama/") {
-                    &pdl_model[7..]
-                } else {
-                    &pdl_model[12..]
-                };
+        let mut ollama = Ollama::default();
+        let model = if pdl_model.starts_with("ollama/") {
+            &pdl_model[7..]
+        } else {
+            &pdl_model[12..]
+        };

-                let (options, tools) = self.to_ollama_model_options(&block.parameters);
-                if self.options.debug {
-                    eprintln!("Model options {:?} {:?}", metadata.description, options);
-                    eprintln!("Model tools {:?} {:?}", metadata.description, tools);
-                }
+        let (options, tools) = self.to_ollama_model_options(&block.parameters);
+        if self.options.debug {
+            eprintln!("Model options {:?} {:?}", metadata.description, options);
+            eprintln!("Model tools {:?} {:?}", metadata.description, tools);
+        }

-                // The input messages to the model is either:
-                // a) block.input, if given
-                // b) the current state's accumulated messages
-                let input_messages = match &block.input {
-                    Some(input) => {
-                        // TODO ignoring result, trace
-                        let (_result, messages, _trace) = self.run_quiet(&*input, state).await?;
-                        messages
-                    }
-                    None => state.messages.clone(),
-                };
-                let (prompt, history_slice): (&ChatMessage, &[ChatMessage]) =
-                    match input_messages.split_last() {
-                        Some(x) => x,
-                        None => (&ChatMessage::user("".into()), &[]),
-                    };
-                let mut history = Vec::from(history_slice);
-                if self.options.debug {
-                    eprintln!(
-                        "Ollama {:?} model={:?} prompt={:?} history={:?}",
-                        metadata.description, block.model, prompt, history
-                    );
-                }
+        // The input messages to the model is either:
+        // a) block.input, if given
+        // b) the current state's accumulated messages
+        let input_messages = match &block.input {
+            Some(input) => {
+                // TODO ignoring result, trace
+                let (_result, messages, _trace) = self.run_quiet(&*input, state).await?;
+                messages
+            }
+            None => state.messages.clone(),
+        };
+        let (prompt, history_slice): (&ChatMessage, &[ChatMessage]) =
+            match input_messages.split_last() {
+                Some(x) => x,
+                None => (&ChatMessage::user("".into()), &[]),
+            };
+        let mut history = Vec::from(history_slice);
+        if self.options.debug {
+            eprintln!(
+                "Ollama {:?} model={:?} prompt={:?} history={:?}",
+                metadata.description, block.model, prompt, history
+            );
+        }

-                //if state.emit {
-                //println!("{}", pretty_print(&input_messages));
-                //}
-
-                let req = ChatMessageRequest::new(model.into(), vec![prompt.clone()])
-                    .options(options)
-                    .tools(tools);
-
-                let (last_res, response_string) = if !self.options.stream {
-                    let res = ollama
-                        .send_chat_messages_with_history(&mut history, req)
-                        .await?;
-                    let response_string = res.message.content.clone();
-                    print!("{}", response_string);
-                    (Some(res), response_string)
-                } else {
-                    let mut stream = ollama
-                        .send_chat_messages_with_history_stream(
-                            ::std::sync::Arc::new(::std::sync::Mutex::new(history)),
-                            req,
-                            //ollama.generate(GenerationRequest::new(model.into(), prompt),
-                        )
-                        .await?;
-                    // dbg!("Model result {:?}", &res);
-
-                    let emit = if let Some(_) = &block.model_response {
-                        false
-                    } else {
-                        true
-                    };
-
-                    let mut last_res: Option<ChatMessageResponse> = None;
-                    let mut response_string = String::new();
-                    let mut stdout = stdout();
-                    if emit {
-                        stdout.write_all(b"\x1b[1mAssistant: \x1b[0m").await?;
-                    }
-                    while let Some(Ok(res)) = stream.next().await {
-                        if emit {
-                            stdout.write_all(b"\x1b[32m").await?; // green
-                            stdout.write_all(res.message.content.as_bytes()).await?;
-                            stdout.flush().await?;
-                            stdout.write_all(b"\x1b[0m").await?; // reset color
-                        }
-                        response_string += res.message.content.as_str();
-                        last_res = Some(res);
-                    }
-                    if emit {
-                        stdout.write_all(b"\n").await?;
-                    }
+        //if state.emit {
+        //println!("{}", pretty_print(&input_messages));
+        //}
+
+        let req = ChatMessageRequest::new(model.into(), vec![prompt.clone()])
+            .options(options)
+            .tools(tools);
+
+        let (last_res, response_string) = if !self.options.stream {
+            let res = ollama
+                .send_chat_messages_with_history(&mut history, req)
+                .await?;
+            let response_string = res.message.content.clone();
+            print!("{}", response_string);
+            (Some(res), response_string)
+        } else {
+            let mut stream = ollama
+                .send_chat_messages_with_history_stream(
+                    ::std::sync::Arc::new(::std::sync::Mutex::new(history)),
+                    req,
+                    //ollama.generate(GenerationRequest::new(model.into(), prompt),
+                )
+                .await?;
+            // dbg!("Model result {:?}", &res);

-                    (last_res, response_string)
-                };
-
-                if let Some(_) = &block.model_response {
-                    if let Some(ref res) = last_res {
-                        self.def(
-                            &block.model_response,
-                            &resultify_as_litellm(&from_str(&to_string(&res)?)?),
-                            &None,
-                            state,
-                            true,
-                        )?;
-                    }
-                }
+            let emit = if let Some(_) = &block.model_response {
+                false
+            } else {
+                true
+            };

-                let mut trace = block.clone();
-                if let Some(res) = last_res {
-                    if let Some(usage) = res.final_data {
-                        trace.pdl_usage = Some(PdlUsage {
-                            prompt_tokens: usage.prompt_eval_count,
-                            prompt_nanos: Some(usage.prompt_eval_duration),
-                            completion_tokens: usage.eval_count,
-                            completion_nanos: Some(usage.eval_duration),
-                        });
-                    }
-                    let output_messages = vec![ChatMessage::assistant(response_string)];
-                    Ok((res.message.content.into(), output_messages, Model(trace)))
-                } else {
-                    // nothing came out of the model
-                    Ok(("".into(), vec![], Model(trace)))
+            let mut last_res: Option<ChatMessageResponse> = None;
+            let mut response_string = String::new();
+            let mut stdout = stdout();
+            if emit {
+                stdout.write_all(b"\x1b[1mAssistant: \x1b[0m").await?;
+            }
+            while let Some(Ok(res)) = stream.next().await {
+                if emit {
+                    stdout.write_all(b"\x1b[32m").await?; // green
+                    stdout.write_all(res.message.content.as_bytes()).await?;
+                    stdout.flush().await?;
+                    stdout.write_all(b"\x1b[0m").await?; // reset color
                 }
-                // dbg!(history);
+                response_string += res.message.content.as_str();
+                last_res = Some(res);
+            }
+            if emit {
+                stdout.write_all(b"\n").await?;
+            }
+
+            (last_res, response_string)
+        };
+
+        if let Some(_) = &block.model_response {
+            if let Some(ref res) = last_res {
+                self.def(
+                    &block.model_response,
+                    &resultify_as_litellm(&from_str(&to_string(&res)?)?),
+                    &None,
+                    state,
+                    true,
+                )?;
             }
-            _ => Err(Box::from(format!("Unsupported model {}", block.model))),
         }
+
+        let mut trace = block.clone();
+        if let Some(res) = last_res {
+            if let Some(usage) = res.final_data {
+                trace.pdl_usage = Some(PdlUsage {
+                    prompt_tokens: usage.prompt_eval_count,
+                    prompt_nanos: Some(usage.prompt_eval_duration),
+                    completion_tokens: usage.eval_count,
+                    completion_nanos: Some(usage.eval_duration),
+                });
+            }
+            let output_messages = vec![ChatMessage::assistant(response_string)];
+            Ok((res.message.content.into(), output_messages, Model(trace)))
+        } else {
+            // nothing came out of the model
+            Ok(("".into(), vec![], Model(trace)))
+        }
+        // dbg!(history);
+    }
+
+    /// Run a PdlBlock::Model
+    async fn run_model(
+        &mut self,
+        block: &ModelBlock,
+        metadata: &Metadata,
+        state: &mut State,
+    ) -> BodyInterpretation {
+        if let PdlResult::String(s) = self.eval_string_to_string(&block.model, state)? {
+            if s.starts_with("ollama/") || s.starts_with("ollama_chat/") {
+                return self.run_ollama_model(s, block, metadata, state).await;
+            }
+        }
+
+        Err(Box::from(format!("Unsupported model {:?}", block.model)))
     }

     /// Run a PdlBlock::Data
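With this split, run_model first evaluates `block.model` to a string and only then dispatches; the slice offsets 7 and 12 in run_ollama_model are the lengths of "ollama/" and "ollama_chat/". A standalone sketch of that dispatch logic (helper names are hypothetical; the error string mirrors the one in interpreter.rs):

// Strip the provider prefix, as run_ollama_model does before calling Ollama.
fn strip_ollama_prefix(pdl_model: &str) -> Option<&str> {
    if pdl_model.starts_with("ollama/") {
        Some(&pdl_model["ollama/".len()..]) // same as &pdl_model[7..]
    } else if pdl_model.starts_with("ollama_chat/") {
        Some(&pdl_model["ollama_chat/".len()..]) // same as &pdl_model[12..]
    } else {
        None
    }
}

// Mirror of run_model's control flow after evaluating the model field.
fn dispatch(pdl_model: &str) -> Result<String, String> {
    match strip_ollama_prefix(pdl_model) {
        Some(model) => Ok(model.to_string()), // run_ollama_model takes it from here
        None => Err(format!("Unsupported model {:?}", pdl_model)),
    }
}

fn main() {
    assert_eq!(dispatch("ollama/granite"), Ok("granite".to_string()));
    assert_eq!(dispatch("ollama_chat/granite"), Ok("granite".to_string()));
    assert!(dispatch("watsonx/granite").is_err());
}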

pdl-live-react/src-tauri/src/pdl/interpreter_tests.rs

Lines changed: 20 additions & 0 deletions

@@ -79,6 +79,26 @@ mod tests {
         Ok(())
     }

+    #[test]
+    fn single_model_via_text_chain_expr() -> Result<(), Box<dyn Error>> {
+        let (_, messages, _) = run_json(
+            json!({
+                "text": [
+                    "hello",
+                    {"model": { "pdl__expr": DEFAULT_MODEL }}
+                ]
+            }),
+            streaming(),
+            initial_scope(),
+        )?;
+        assert_eq!(messages.len(), 2);
+        assert_eq!(messages[0].role, MessageRole::User);
+        assert_eq!(messages[0].content, "hello");
+        assert_eq!(messages[1].role, MessageRole::Assistant);
+        assert!(messages[1].content.contains("Hello!"));
+        Ok(())
+    }
+
     #[test]
     fn single_model_via_text_chain() -> Result<(), Box<dyn Error>> {
         let (_, messages, _) = run_json(