Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions pdl-live-react/src-tauri/src/pdl/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ pub enum PdlType {

/// Call a function
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "kind", rename = "call")]
pub struct CallBlock {
/// Function to call
pub call: String,
Expand Down Expand Up @@ -91,6 +92,7 @@ pub trait SequencingBlock {

/// Return the value of the last block if the list of blocks
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "kind", rename = "lastOf")]
pub struct LastOfBlock {
/// Sequence of blocks to execute
#[serde(rename = "lastOf")]
Expand Down Expand Up @@ -158,6 +160,7 @@ impl SequencingBlock for LastOfBlock {
/// Create the concatenation of the stringify version of the result of
/// each block of the list of blocks.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "kind", rename = "text")]
pub struct TextBlock {
/// Body of the text
pub text: Vec<PdlBlock>,
Expand Down Expand Up @@ -260,6 +263,7 @@ impl From<Vec<PdlBlock>> for TextBlock {
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "kind", rename = "function")]
pub struct FunctionBlock {
pub function: HashMap<String, PdlType>,
#[serde(rename = "return")]
Expand All @@ -279,6 +283,7 @@ pub struct PdlUsage {
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "kind", rename = "model")]
pub struct ModelBlock {
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
Expand Down Expand Up @@ -352,6 +357,7 @@ pub enum ListOrString {
/// "${ name }'s number is ${ number }\\n"
/// ```
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "kind", rename = "repeat")]
pub struct RepeatBlock {
/// Arrays to iterate over
#[serde(rename = "for")]
Expand All @@ -363,6 +369,7 @@ pub struct RepeatBlock {

/// Create a message
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "kind", rename = "message")]
pub struct MessageBlock {
/// Role of associated to the message, e.g. User or Assistant
pub role: Role,
Expand All @@ -386,6 +393,7 @@ pub struct MessageBlock {
/// block. If the body of the object is an array, the resulting object
/// is the union of the objects computed by each element of the array.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "kind", rename = "object")]
pub struct ObjectBlock {
pub object: HashMap<String, PdlBlock>,
}
Expand Down Expand Up @@ -413,6 +421,7 @@ pub struct ObjectBlock {
/// def: EXTRACTED_GROUND_TRUTH
/// ```
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "kind")]
pub struct DataBlock {
pub data: Value,

Expand All @@ -438,6 +447,7 @@ pub struct DataBlock {
/// result = random.randint(1, 20)
/// ```
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "kind", rename = "code")]
pub struct PythonCodeBlock {
pub lang: String,
pub code: String,
Expand All @@ -464,6 +474,7 @@ pub enum StringOrNull {
/// parser: yaml
/// ```
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "kind", rename = "read")]
pub struct ReadBlock {
/// Name of the file to read. If `None`, read the standard input.
pub read: StringOrNull,
Expand Down Expand Up @@ -500,6 +511,7 @@ pub enum StringOrBoolean {
/// then: You won!
/// ```
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "kind", rename = "if")]
pub struct IfBlock {
/// The condition to check
#[serde(rename = "if")]
Expand All @@ -519,31 +531,36 @@ pub struct IfBlock {

/// Return the array of values computed by each block of the list of blocks
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "kind", rename = "array")]
pub struct ArrayBlock {
/// Elements of the array
pub array: Vec<PdlBlock>,
}

/// Include a PDL file
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "kind", rename = "include")]
pub struct IncludeBlock {
/// Name of the file to include.
pub include: String,
}

/// Import a PDL file
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "kind", rename = "import")]
pub struct ImportBlock {
/// Name of the file to include.
pub import: String,
}

/// Block containing only defs
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "kind", rename = "empty")]
pub struct EmptyBlock {
pub defs: IndexMap<String, PdlBlock>,
}

/// A PDL program/sub-program consists of either a literal (string, number, boolean) or some kind of structured block
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum PdlBlock {
Expand All @@ -570,6 +587,12 @@ pub enum PdlBlock {
Empty(EmptyBlock),
}

impl From<bool> for PdlBlock {
fn from(b: bool) -> Self {
PdlBlock::Bool(b)
}
}

impl From<&str> for PdlBlock {
fn from(s: &str) -> Self {
PdlBlock::String(s.into())
Expand Down
22 changes: 9 additions & 13 deletions pdl-live-react/src-tauri/src/pdl/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ impl<'a> Interpreter<'a> {
self.emit = emit;

let (result, messages, trace) = match program {
PdlBlock::Bool(b) => Ok((
b.into(),
vec![ChatMessage::user(format!("{b}"))],
PdlBlock::Bool(b.clone()),
)),
PdlBlock::Number(n) => Ok((
n.clone().into(),
vec![ChatMessage::user(format!("{n}"))],
Expand All @@ -113,7 +118,6 @@ impl<'a> Interpreter<'a> {
PdlBlock::Text(block) => self.run_sequence(block, context).await,
PdlBlock::Array(block) => self.run_array(block, context).await,
PdlBlock::Message(block) => self.run_message(block, context).await,
_ => Err(Box::from(format!("Unsupported block {:?}", program))),
}?;

if match program {
Expand Down Expand Up @@ -324,10 +328,7 @@ impl<'a> Interpreter<'a> {
self.push_and_extend_scope_with(m, c.scope);
Ok(())
}
x => Err(PdlError::from(format!(
"Call arguments not a map: {:?}",
x
))),
x => Err(PdlError::from(format!("Call arguments not a map: {:?}", x))),
}?;
}

Expand Down Expand Up @@ -513,9 +514,7 @@ impl<'a> Interpreter<'a> {
Ok(_) => Ok(()),
Err(exc) => {
vm.print_exception(exc);
Err(PdlError::from(
"Error executing Python code",
))
Err(PdlError::from("Error executing Python code"))
}
}?;

Expand All @@ -525,9 +524,7 @@ impl<'a> Interpreter<'a> {
Ok(x) => Ok(x),
Err(exc) => {
vm.print_exception(exc);
Err(PdlError::from(
"Unable to stringify Python 'result' value",
))
Err(PdlError::from("Unable to stringify Python 'result' value"))
}
}?;
let messages = vec![ChatMessage::user(result_string.as_str().to_string())];
Expand Down Expand Up @@ -998,8 +995,7 @@ pub fn run_sync(

/// Read in a file from disk and parse it as a PDL program
pub fn parse_file(path: &PathBuf) -> Result<PdlBlock, PdlError> {
from_reader(::std::fs::File::open(path)?)
.map_err(|err| PdlError::from(err.to_string()))
from_reader(::std::fs::File::open(path)?).map_err(|err| PdlError::from(err.to_string()))
}

pub async fn run_file(source_file_path: &str, debug: bool, stream: bool) -> Interpretation {
Expand Down