Skip to content

Commit a8cac9b

Browse files
committed
fix: add kind tags to rust ast blocks
This also updates the interpreter so that we get a compile time error if the match-over-block type is not exhaustive. Signed-off-by: Nick Mitchell <[email protected]>
1 parent 828cefa commit a8cac9b

File tree

2 files changed

+32
-13
lines changed

2 files changed

+32
-13
lines changed

pdl-live-react/src-tauri/src/pdl/ast.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ pub enum PdlType {
5353

5454
/// Call a function
5555
#[derive(Serialize, Deserialize, Debug, Clone)]
56+
#[serde(tag = "kind", rename = "call")]
5657
pub struct CallBlock {
5758
/// Function to call
5859
pub call: String,
@@ -91,6 +92,7 @@ pub trait SequencingBlock {
9192

9293
/// Return the value of the last block if the list of blocks
9394
#[derive(Serialize, Deserialize, Debug, Clone)]
95+
#[serde(tag = "kind", rename = "lastOf")]
9496
pub struct LastOfBlock {
9597
/// Sequence of blocks to execute
9698
#[serde(rename = "lastOf")]
@@ -158,6 +160,7 @@ impl SequencingBlock for LastOfBlock {
158160
/// Create the concatenation of the stringify version of the result of
159161
/// each block of the list of blocks.
160162
#[derive(Serialize, Deserialize, Debug, Clone)]
163+
#[serde(tag = "kind", rename = "text")]
161164
pub struct TextBlock {
162165
/// Body of the text
163166
pub text: Vec<PdlBlock>,
@@ -260,6 +263,7 @@ impl From<Vec<PdlBlock>> for TextBlock {
260263
}
261264

262265
#[derive(Serialize, Deserialize, Debug, Clone)]
266+
#[serde(tag = "kind", rename = "function")]
263267
pub struct FunctionBlock {
264268
pub function: HashMap<String, PdlType>,
265269
#[serde(rename = "return")]
@@ -279,6 +283,7 @@ pub struct PdlUsage {
279283
}
280284

281285
#[derive(Serialize, Deserialize, Debug, Clone)]
286+
#[serde(tag = "kind", rename = "model")]
282287
pub struct ModelBlock {
283288
#[serde(skip_serializing_if = "Option::is_none")]
284289
pub description: Option<String>,
@@ -352,6 +357,7 @@ pub enum ListOrString {
352357
/// "${ name }'s number is ${ number }\\n"
353358
/// ```
354359
#[derive(Serialize, Deserialize, Debug, Clone)]
360+
#[serde(tag = "kind", rename = "repeat")]
355361
pub struct RepeatBlock {
356362
/// Arrays to iterate over
357363
#[serde(rename = "for")]
@@ -363,6 +369,7 @@ pub struct RepeatBlock {
363369

364370
/// Create a message
365371
#[derive(Serialize, Deserialize, Debug, Clone)]
372+
#[serde(tag = "kind", rename = "message")]
366373
pub struct MessageBlock {
367374
/// Role of associated to the message, e.g. User or Assistant
368375
pub role: Role,
@@ -386,6 +393,7 @@ pub struct MessageBlock {
386393
/// block. If the body of the object is an array, the resulting object
387394
/// is the union of the objects computed by each element of the array.
388395
#[derive(Serialize, Deserialize, Debug, Clone)]
396+
#[serde(tag = "kind", rename = "object")]
389397
pub struct ObjectBlock {
390398
pub object: HashMap<String, PdlBlock>,
391399
}
@@ -413,6 +421,7 @@ pub struct ObjectBlock {
413421
/// def: EXTRACTED_GROUND_TRUTH
414422
/// ```
415423
#[derive(Serialize, Deserialize, Debug, Clone)]
424+
#[serde(tag = "kind")]
416425
pub struct DataBlock {
417426
pub data: Value,
418427

@@ -438,6 +447,7 @@ pub struct DataBlock {
438447
/// result = random.randint(1, 20)
439448
/// ```
440449
#[derive(Serialize, Deserialize, Debug, Clone)]
450+
#[serde(tag = "kind", rename = "code")]
441451
pub struct PythonCodeBlock {
442452
pub lang: String,
443453
pub code: String,
@@ -464,6 +474,7 @@ pub enum StringOrNull {
464474
/// parser: yaml
465475
/// ```
466476
#[derive(Serialize, Deserialize, Debug, Clone)]
477+
#[serde(tag = "kind", rename = "read")]
467478
pub struct ReadBlock {
468479
/// Name of the file to read. If `None`, read the standard input.
469480
pub read: StringOrNull,
@@ -500,6 +511,7 @@ pub enum StringOrBoolean {
500511
/// then: You won!
501512
/// ```
502513
#[derive(Serialize, Deserialize, Debug, Clone)]
514+
#[serde(tag = "kind", rename = "if")]
503515
pub struct IfBlock {
504516
/// The condition to check
505517
#[serde(rename = "if")]
@@ -519,31 +531,36 @@ pub struct IfBlock {
519531

520532
/// Return the array of values computed by each block of the list of blocks
521533
#[derive(Serialize, Deserialize, Debug, Clone)]
534+
#[serde(tag = "kind", rename = "array")]
522535
pub struct ArrayBlock {
523536
/// Elements of the array
524537
pub array: Vec<PdlBlock>,
525538
}
526539

527540
/// Include a PDL file
528541
#[derive(Serialize, Deserialize, Debug, Clone)]
542+
#[serde(tag = "kind", rename = "include")]
529543
pub struct IncludeBlock {
530544
/// Name of the file to include.
531545
pub include: String,
532546
}
533547

534548
/// Import a PDL file
535549
#[derive(Serialize, Deserialize, Debug, Clone)]
550+
#[serde(tag = "kind", rename = "import")]
536551
pub struct ImportBlock {
537552
/// Name of the file to include.
538553
pub import: String,
539554
}
540555

541556
/// Block containing only defs
542557
#[derive(Serialize, Deserialize, Debug, Clone)]
558+
#[serde(tag = "kind", rename = "empty")]
543559
pub struct EmptyBlock {
544560
pub defs: IndexMap<String, PdlBlock>,
545561
}
546562

563+
/// A PDL program/sub-program consists of either a literal (string, number, boolean) or some kind of structured block
547564
#[derive(Serialize, Deserialize, Debug, Clone)]
548565
#[serde(untagged)]
549566
pub enum PdlBlock {
@@ -570,6 +587,12 @@ pub enum PdlBlock {
570587
Empty(EmptyBlock),
571588
}
572589

590+
impl From<bool> for PdlBlock {
591+
fn from(b: bool) -> Self {
592+
PdlBlock::Bool(b)
593+
}
594+
}
595+
573596
impl From<&str> for PdlBlock {
574597
fn from(s: &str) -> Self {
575598
PdlBlock::String(s.into())

pdl-live-react/src-tauri/src/pdl/interpreter.rs

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,11 @@ impl<'a> Interpreter<'a> {
8787
self.emit = emit;
8888

8989
let (result, messages, trace) = match program {
90+
PdlBlock::Bool(b) => Ok((
91+
b.into(),
92+
vec![ChatMessage::user(format!("{b}"))],
93+
PdlBlock::Bool(b.clone()),
94+
)),
9095
PdlBlock::Number(n) => Ok((
9196
n.clone().into(),
9297
vec![ChatMessage::user(format!("{n}"))],
@@ -113,7 +118,6 @@ impl<'a> Interpreter<'a> {
113118
PdlBlock::Text(block) => self.run_sequence(block, context).await,
114119
PdlBlock::Array(block) => self.run_array(block, context).await,
115120
PdlBlock::Message(block) => self.run_message(block, context).await,
116-
_ => Err(Box::from(format!("Unsupported block {:?}", program))),
117121
}?;
118122

119123
if match program {
@@ -324,10 +328,7 @@ impl<'a> Interpreter<'a> {
324328
self.push_and_extend_scope_with(m, c.scope);
325329
Ok(())
326330
}
327-
x => Err(PdlError::from(format!(
328-
"Call arguments not a map: {:?}",
329-
x
330-
))),
331+
x => Err(PdlError::from(format!("Call arguments not a map: {:?}", x))),
331332
}?;
332333
}
333334

@@ -513,9 +514,7 @@ impl<'a> Interpreter<'a> {
513514
Ok(_) => Ok(()),
514515
Err(exc) => {
515516
vm.print_exception(exc);
516-
Err(PdlError::from(
517-
"Error executing Python code",
518-
))
517+
Err(PdlError::from("Error executing Python code"))
519518
}
520519
}?;
521520

@@ -525,9 +524,7 @@ impl<'a> Interpreter<'a> {
525524
Ok(x) => Ok(x),
526525
Err(exc) => {
527526
vm.print_exception(exc);
528-
Err(PdlError::from(
529-
"Unable to stringify Python 'result' value",
530-
))
527+
Err(PdlError::from("Unable to stringify Python 'result' value"))
531528
}
532529
}?;
533530
let messages = vec![ChatMessage::user(result_string.as_str().to_string())];
@@ -998,8 +995,7 @@ pub fn run_sync(
998995

999996
/// Read in a file from disk and parse it as a PDL program
1000997
pub fn parse_file(path: &PathBuf) -> Result<PdlBlock, PdlError> {
1001-
from_reader(::std::fs::File::open(path)?)
1002-
.map_err(|err| PdlError::from(err.to_string()))
998+
from_reader(::std::fs::File::open(path)?).map_err(|err| PdlError::from(err.to_string()))
1003999
}
10041000

10051001
pub async fn run_file(source_file_path: &str, debug: bool, stream: bool) -> Interpretation {

0 commit comments

Comments
 (0)