Skip to content

Commit c812cc9

Browse files
committed
refactor: introduce Advanced enum to rust AST
Signed-off-by: Nick Mitchell <[email protected]>
1 parent a0855d3 commit c812cc9

File tree

5 files changed

+96
-96
lines changed

5 files changed

+96
-96
lines changed

pdl-live-react/src-tauri/src/compile/beeai.rs

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@ use serde_json::{Map, Value, from_reader, json, to_string};
1212
use tempfile::Builder;
1313

1414
use crate::pdl::ast::{
15-
ArrayBlockBuilder, CallBlock, EvalsTo, FunctionBlock, ListOrString, MessageBlock,
16-
MetadataBuilder, ModelBlockBuilder, ObjectBlock, PdlBaseType, PdlBlock, PdlOptionalType,
17-
PdlParser, PdlType, PythonCodeBlock, RepeatBlock, Role, TextBlock, TextBlockBuilder,
15+
ArrayBlockBuilder, Block::*, CallBlock, EvalsTo, FunctionBlock, ListOrString, MessageBlock,
16+
MetadataBuilder, ModelBlockBuilder, ObjectBlock, PdlBaseType, PdlBlock, PdlBlock::Advanced,
17+
PdlOptionalType, PdlParser, PdlType, PythonCodeBlock, RepeatBlock, Role, TextBlock,
18+
TextBlockBuilder,
1819
};
1920
use crate::pdl::pip::pip_install_if_needed;
2021
use crate::pdl::requirements::BEEAI_FRAMEWORK;
@@ -190,7 +191,7 @@ fn with_tools(
190191
}
191192

192193
fn call_tools(model: &String, parameters: &HashMap<String, Value>) -> PdlBlock {
193-
let repeat = PdlBlock::Text(TextBlock {
194+
let repeat = Advanced(Text(TextBlock {
194195
metadata: Some(
195196
MetadataBuilder::default()
196197
.description("Calling tool ${ tool.function.name }".to_string())
@@ -199,19 +200,19 @@ fn call_tools(model: &String, parameters: &HashMap<String, Value>) -> PdlBlock {
199200
),
200201
role: None,
201202
parser: None,
202-
text: vec![PdlBlock::Model(
203+
text: vec![Advanced(Model(
203204
ModelBlockBuilder::default()
204205
.model(model.as_str())
205206
.parameters(strip_nulls(parameters))
206-
.input(PdlBlock::Array(
207+
.input(Advanced(Array(
207208
ArrayBlockBuilder::default()
208-
.array(vec![PdlBlock::Message(MessageBlock {
209+
.array(vec![Advanced(Message(MessageBlock {
209210
metadata: None,
210211
role: Role::Tool,
211212
defsite: None,
212213
name: Some("${ tool.function.name }".to_string()),
213214
tool_call_id: Some("${ tool.id }".to_string()),
214-
content: Box::new(PdlBlock::Call(CallBlock {
215+
content: Box::new(Advanced(Call(CallBlock {
215216
metadata: Some(
216217
MetadataBuilder::default()
217218
.defs(json_loads(
@@ -226,15 +227,15 @@ fn call_tools(model: &String, parameters: &HashMap<String, Value>) -> PdlBlock {
226227
"${ pdl__tools[tool.function.name] }".to_string(),
227228
), // look up tool in tool_declarations def (see below)
228229
args: Some("${ args }".into()), // invoke with arguments as specified by the model
229-
})),
230-
})])
230+
}))),
231+
}))])
231232
.build()
232233
.unwrap(),
233-
))
234+
)))
234235
.build()
235236
.unwrap(),
236-
)],
237-
});
237+
))],
238+
}));
238239

239240
let mut for_ = HashMap::new();
240241
for_.insert(
@@ -243,11 +244,11 @@ fn call_tools(model: &String, parameters: &HashMap<String, Value>) -> PdlBlock {
243244
);
244245

245246
// response.choices[0].message.tool_calls
246-
PdlBlock::Repeat(RepeatBlock {
247+
Advanced(Repeat(RepeatBlock {
247248
metadata: None,
248249
for_: for_,
249250
repeat: Box::new(repeat),
250-
})
251+
}))
251252
}
252253

253254
fn json_loads(
@@ -258,7 +259,7 @@ fn json_loads(
258259
let mut m = indexmap::IndexMap::new();
259260
m.insert(
260261
outer_name.to_owned(),
261-
PdlBlock::Text(
262+
Advanced(Text(
262263
TextBlockBuilder::default()
263264
.text(vec![PdlBlock::String(format!(
264265
"{{\"{}\": {}}}",
@@ -273,7 +274,7 @@ fn json_loads(
273274
.parser(PdlParser::Json)
274275
.build()
275276
.unwrap(),
276-
),
277+
)),
277278
);
278279
m
279280
}
@@ -422,7 +423,7 @@ pub fn compile(source_file_path: &str, debug: bool) -> Result<PdlBlock, Box<dyn
422423
tool_name.clone(),
423424
PdlBlock::Function(FunctionBlock {
424425
function: schema,
425-
return_: Box::new(PdlBlock::PythonCode(PythonCodeBlock {
426+
return_: Box::new(Advanced(PythonCode(PythonCodeBlock {
426427
// tool function definition
427428
metadata: None,
428429
lang: "python".to_string(),
@@ -456,7 +457,7 @@ asyncio.run(invoke())
456457
"".to_string()
457458
}
458459
),
459-
})),
460+
}))),
460461
}),
461462
)
462463
})
@@ -485,7 +486,7 @@ asyncio.run(invoke())
485486
let model = format!("{}/{}", provider, model);
486487

487488
if let Some(instructions) = instructions {
488-
model_call.push(PdlBlock::Text(TextBlock {
489+
model_call.push(Advanced(Text(TextBlock {
489490
role: Some(Role::System),
490491
text: vec![PdlBlock::String(instructions)],
491492
metadata: Some(
@@ -495,7 +496,7 @@ asyncio.run(invoke())
495496
.unwrap(),
496497
),
497498
parser: None,
498-
}));
499+
})));
499500
}
500501

501502
let mut model_builder = ModelBlockBuilder::default();
@@ -518,7 +519,7 @@ asyncio.run(invoke())
518519
}
519520
}
520521

521-
model_call.push(PdlBlock::Model(model_builder.build().unwrap()));
522+
model_call.push(Advanced(Model(model_builder.build().unwrap())));
522523

523524
if let Some(tools) = tools {
524525
if tools.len() > 0 {
@@ -532,7 +533,7 @@ asyncio.run(invoke())
532533
closure_name.clone(),
533534
PdlBlock::Function(FunctionBlock {
534535
function: HashMap::new(),
535-
return_: Box::new(PdlBlock::Text(TextBlock {
536+
return_: Box::new(Advanced(Text(TextBlock {
536537
metadata: Some(
537538
MetadataBuilder::default()
538539
.description(format!("Model call {}", &model))
@@ -542,10 +543,10 @@ asyncio.run(invoke())
542543
role: None,
543544
parser: None,
544545
text: model_call,
545-
})),
546+
}))),
546547
}),
547548
);
548-
PdlBlock::Text(TextBlock {
549+
Advanced(Text(TextBlock {
549550
metadata: Some(
550551
MetadataBuilder::default()
551552
.description("Model call wrapper".to_string())
@@ -555,11 +556,11 @@ asyncio.run(invoke())
555556
),
556557
role: None,
557558
parser: None,
558-
text: vec![PdlBlock::Call(CallBlock::new(format!(
559+
text: vec![Advanced(Call(CallBlock::new(format!(
559560
"${{ {} }}",
560561
closure_name
561-
)))],
562-
})
562+
))))],
563+
}))
563564
},
564565
)
565566
.collect::<Vec<_>>();
@@ -574,19 +575,19 @@ asyncio.run(invoke())
574575
let mut defs = indexmap::IndexMap::new();
575576
defs.insert(
576577
"pdl__tools".to_string(),
577-
PdlBlock::Object(ObjectBlock {
578+
Advanced(Object(ObjectBlock {
578579
object: tool_declarations,
579-
}),
580+
})),
580581
);
581582
metadata.defs(defs);
582583
}
583584

584-
let pdl: PdlBlock = PdlBlock::Text(TextBlock {
585+
let pdl: PdlBlock = Advanced(Text(TextBlock {
585586
metadata: Some(metadata.build().unwrap()),
586587
role: None,
587588
parser: None,
588589
text: body,
589-
});
590+
}));
590591

591592
Ok(pdl)
592593
}

pdl-live-react/src-tauri/src/pdl/ast.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ impl SequencingBlock for LastOfBlock {
183183
&self.parser
184184
}
185185
fn to_block(&self) -> PdlBlock {
186-
PdlBlock::LastOf(self.clone())
186+
PdlBlock::Advanced(Block::LastOf(self.clone()))
187187
}
188188
fn result_for(&self, output_results: Vec<PdlResult>) -> PdlResult {
189189
match output_results.last() {
@@ -243,7 +243,7 @@ impl SequencingBlock for TextBlock {
243243
&self.parser
244244
}
245245
fn to_block(&self) -> PdlBlock {
246-
PdlBlock::Text(self.clone())
246+
PdlBlock::Advanced(Block::Text(self.clone()))
247247
}
248248
fn result_for(&self, output_results: Vec<PdlResult>) -> PdlResult {
249249
PdlResult::String(
@@ -596,8 +596,15 @@ pub enum PdlBlock {
596596
Number(Number),
597597
String(String),
598598
Function(FunctionBlock),
599+
Advanced(Block),
600+
// must be last to prevent serde from aggressively matching on it, since other block types also (may) have a `defs`
601+
Empty(EmptyBlock),
602+
}
599603

600-
// the rest have Metadata; TODO refactor to make this more explicit
604+
/// A PDL block that has structure and metadata
605+
#[derive(Serialize, Deserialize, Debug, Clone)]
606+
#[serde(untagged)]
607+
pub enum Block {
601608
If(IfBlock),
602609
Import(ImportBlock),
603610
Include(IncludeBlock),
@@ -612,9 +619,6 @@ pub enum PdlBlock {
612619
Model(ModelBlock),
613620
LastOf(LastOfBlock),
614621
Text(TextBlock),
615-
616-
// must be last to prevent serde from aggressively matching on it, since other block types also (may) have a `defs`
617-
Empty(EmptyBlock),
618622
}
619623

620624
impl From<bool> for PdlBlock {

pdl-live-react/src-tauri/src/pdl/extract.rs

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::pdl::ast::{Metadata, PdlBlock};
1+
use crate::pdl::ast::{Block::*, Metadata, PdlBlock, PdlBlock::Advanced};
22

33
/// Extract models referenced by the programs
44
pub fn extract_models(program: &PdlBlock) -> Vec<String> {
@@ -20,18 +20,26 @@ pub fn extract_values(program: &PdlBlock, field: &str) -> Vec<String> {
2020
/// Take one Yaml fragment and produce a vector of the string-valued entries of the given field
2121
fn extract_values_iter(program: &PdlBlock, field: &str, values: &mut Vec<String>) {
2222
match program {
23-
PdlBlock::Model(b) => values.push(b.model.clone()),
24-
PdlBlock::Repeat(b) => {
23+
PdlBlock::Empty(b) => {
24+
b.defs
25+
.values()
26+
.for_each(|p| extract_values_iter(p, field, values));
27+
}
28+
PdlBlock::Function(b) => {
29+
extract_values_iter(&b.return_, field, values);
30+
}
31+
Advanced(Model(b)) => values.push(b.model.clone()),
32+
Advanced(Repeat(b)) => {
2533
extract_values_iter(&b.repeat, field, values);
2634
}
27-
PdlBlock::Message(b) => {
35+
Advanced(Message(b)) => {
2836
extract_values_iter(&b.content, field, values);
2937
}
30-
PdlBlock::Array(b) => b
38+
Advanced(Array(b)) => b
3139
.array
3240
.iter()
3341
.for_each(|p| extract_values_iter(p, field, values)),
34-
PdlBlock::Text(b) => {
42+
Advanced(Text(b)) => {
3543
b.text
3644
.iter()
3745
.for_each(|p| extract_values_iter(p, field, values));
@@ -43,7 +51,7 @@ fn extract_values_iter(program: &PdlBlock, field: &str, values: &mut Vec<String>
4351
.for_each(|p| extract_values_iter(p, field, values));
4452
}
4553
}
46-
PdlBlock::LastOf(b) => {
54+
Advanced(LastOf(b)) => {
4755
b.last_of
4856
.iter()
4957
.for_each(|p| extract_values_iter(p, field, values));
@@ -55,7 +63,7 @@ fn extract_values_iter(program: &PdlBlock, field: &str, values: &mut Vec<String>
5563
.for_each(|p| extract_values_iter(p, field, values));
5664
}
5765
}
58-
PdlBlock::If(b) => {
66+
Advanced(If(b)) => {
5967
extract_values_iter(&b.then, field, values);
6068
if let Some(else_) = &b.else_ {
6169
extract_values_iter(else_, field, values);
@@ -68,20 +76,11 @@ fn extract_values_iter(program: &PdlBlock, field: &str, values: &mut Vec<String>
6876
.for_each(|p| extract_values_iter(p, field, values));
6977
}
7078
}
71-
PdlBlock::Empty(b) => {
72-
b.defs
73-
.values()
74-
.for_each(|p| extract_values_iter(p, field, values));
75-
}
76-
PdlBlock::Object(b) => b
79+
Advanced(Object(b)) => b
7780
.object
7881
.values()
7982
.for_each(|p| extract_values_iter(p, field, values)),
8083

81-
PdlBlock::Function(b) => {
82-
extract_values_iter(&b.return_, field, values);
83-
}
84-
8584
_ => {}
8685
}
8786
}

0 commit comments

Comments
 (0)