Skip to content

Commit 36e9b68

Browse files
mandelstarpit
authored andcommitted
feat: update rust Repeat AST to use Expr for for attr (#904)
Signed-off-by: Louis Mandel <[email protected]> Signed-off-by: Nick Mitchell <[email protected]>
1 parent db348ab commit 36e9b68

File tree

5 files changed

+96
-96
lines changed

5 files changed

+96
-96
lines changed

pdl-live-react/src-tauri/src/compile/beeai.rs

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@ use serde_json::{Map, Value, from_reader, json, to_string};
1212
use tempfile::Builder;
1313

1414
use crate::pdl::ast::{
15-
ArrayBlockBuilder, CallBlock, EvalsTo, Expr, FunctionBlock, ListOrString, MessageBlock,
16-
MetadataBuilder, ModelBlockBuilder, ObjectBlock, PdlBaseType, PdlBlock, PdlOptionalType,
17-
PdlParser, PdlType, PythonCodeBlock, RepeatBlock, Role, TextBlock, TextBlockBuilder,
15+
ArrayBlockBuilder, Block::*, CallBlock, EvalsTo, Expr, FunctionBlock, ListOrString,
16+
MessageBlock, MetadataBuilder, ModelBlockBuilder, ObjectBlock, PdlBaseType, PdlBlock,
17+
PdlBlock::Advanced, PdlOptionalType, PdlParser, PdlType, PythonCodeBlock, RepeatBlock, Role,
18+
TextBlock, TextBlockBuilder,
1819
};
1920
use crate::pdl::pip::pip_install_if_needed;
2021
use crate::pdl::requirements::BEEAI_FRAMEWORK;
@@ -190,7 +191,7 @@ fn with_tools(
190191
}
191192

192193
fn call_tools(model: &String, parameters: &HashMap<String, Value>) -> PdlBlock {
193-
let repeat = PdlBlock::Text(TextBlock {
194+
let repeat = Advanced(Text(TextBlock {
194195
metadata: Some(
195196
MetadataBuilder::default()
196197
.description("Calling tool ${ tool.function.name }".to_string())
@@ -199,19 +200,19 @@ fn call_tools(model: &String, parameters: &HashMap<String, Value>) -> PdlBlock {
199200
),
200201
role: None,
201202
parser: None,
202-
text: vec![PdlBlock::Model(
203+
text: vec![Advanced(Model(
203204
ModelBlockBuilder::default()
204205
.model(model.as_str())
205206
.parameters(strip_nulls(parameters))
206-
.input(PdlBlock::Array(
207+
.input(Advanced(Array(
207208
ArrayBlockBuilder::default()
208-
.array(vec![PdlBlock::Message(MessageBlock {
209+
.array(vec![Advanced(Message(MessageBlock {
209210
metadata: None,
210211
role: Role::Tool,
211212
defsite: None,
212213
name: Some("${ tool.function.name }".to_string()),
213214
tool_call_id: Some("${ tool.id }".to_string()),
214-
content: Box::new(PdlBlock::Call(CallBlock {
215+
content: Box::new(Advanced(Call(CallBlock {
215216
metadata: Some(
216217
MetadataBuilder::default()
217218
.defs(json_loads(
@@ -226,15 +227,15 @@ fn call_tools(model: &String, parameters: &HashMap<String, Value>) -> PdlBlock {
226227
"${ pdl__tools[tool.function.name] }".to_string(),
227228
), // look up tool in tool_declarations def (see below)
228229
args: Some("${ args }".into()), // invoke with arguments as specified by the model
229-
})),
230-
})])
230+
}))),
231+
}))])
231232
.build()
232233
.unwrap(),
233-
))
234+
)))
234235
.build()
235236
.unwrap(),
236-
)],
237-
});
237+
))],
238+
}));
238239

239240
let mut for_ = HashMap::new();
240241
for_.insert(
@@ -248,11 +249,11 @@ fn call_tools(model: &String, parameters: &HashMap<String, Value>) -> PdlBlock {
248249
);
249250

250251
// response.choices[0].message.tool_calls
251-
PdlBlock::Repeat(RepeatBlock {
252+
Advanced(Repeat(RepeatBlock {
252253
metadata: None,
253254
for_: for_,
254255
repeat: Box::new(repeat),
255-
})
256+
}))
256257
}
257258

258259
fn json_loads(
@@ -263,7 +264,7 @@ fn json_loads(
263264
let mut m = indexmap::IndexMap::new();
264265
m.insert(
265266
outer_name.to_owned(),
266-
PdlBlock::Text(
267+
Advanced(Text(
267268
TextBlockBuilder::default()
268269
.text(vec![PdlBlock::String(format!(
269270
"{{\"{}\": {}}}",
@@ -278,7 +279,7 @@ fn json_loads(
278279
.parser(PdlParser::Json)
279280
.build()
280281
.unwrap(),
281-
),
282+
)),
282283
);
283284
m
284285
}
@@ -427,7 +428,7 @@ pub fn compile(source_file_path: &str, debug: bool) -> Result<PdlBlock, Box<dyn
427428
tool_name.clone(),
428429
PdlBlock::Function(FunctionBlock {
429430
function: schema,
430-
return_: Box::new(PdlBlock::PythonCode(PythonCodeBlock {
431+
return_: Box::new(Advanced(PythonCode(PythonCodeBlock {
431432
// tool function definition
432433
metadata: None,
433434
lang: "python".to_string(),
@@ -461,7 +462,7 @@ asyncio.run(invoke())
461462
"".to_string()
462463
}
463464
),
464-
})),
465+
}))),
465466
}),
466467
)
467468
})
@@ -490,7 +491,7 @@ asyncio.run(invoke())
490491
let model = format!("{}/{}", provider, model);
491492

492493
if let Some(instructions) = instructions {
493-
model_call.push(PdlBlock::Text(TextBlock {
494+
model_call.push(Advanced(Text(TextBlock {
494495
role: Some(Role::System),
495496
text: vec![PdlBlock::String(instructions)],
496497
metadata: Some(
@@ -500,7 +501,7 @@ asyncio.run(invoke())
500501
.unwrap(),
501502
),
502503
parser: None,
503-
}));
504+
})));
504505
}
505506

506507
let mut model_builder = ModelBlockBuilder::default();
@@ -523,7 +524,7 @@ asyncio.run(invoke())
523524
}
524525
}
525526

526-
model_call.push(PdlBlock::Model(model_builder.build().unwrap()));
527+
model_call.push(Advanced(Model(model_builder.build().unwrap())));
527528

528529
if let Some(tools) = tools {
529530
if tools.len() > 0 {
@@ -537,7 +538,7 @@ asyncio.run(invoke())
537538
closure_name.clone(),
538539
PdlBlock::Function(FunctionBlock {
539540
function: HashMap::new(),
540-
return_: Box::new(PdlBlock::Text(TextBlock {
541+
return_: Box::new(Advanced(Text(TextBlock {
541542
metadata: Some(
542543
MetadataBuilder::default()
543544
.description(format!("Model call {}", &model))
@@ -547,10 +548,10 @@ asyncio.run(invoke())
547548
role: None,
548549
parser: None,
549550
text: model_call,
550-
})),
551+
}))),
551552
}),
552553
);
553-
PdlBlock::Text(TextBlock {
554+
Advanced(Text(TextBlock {
554555
metadata: Some(
555556
MetadataBuilder::default()
556557
.description("Model call wrapper".to_string())
@@ -560,11 +561,11 @@ asyncio.run(invoke())
560561
),
561562
role: None,
562563
parser: None,
563-
text: vec![PdlBlock::Call(CallBlock::new(format!(
564+
text: vec![Advanced(Call(CallBlock::new(format!(
564565
"${{ {} }}",
565566
closure_name
566-
)))],
567-
})
567+
))))],
568+
}))
568569
},
569570
)
570571
.collect::<Vec<_>>();
@@ -579,19 +580,19 @@ asyncio.run(invoke())
579580
let mut defs = indexmap::IndexMap::new();
580581
defs.insert(
581582
"pdl__tools".to_string(),
582-
PdlBlock::Object(ObjectBlock {
583+
Advanced(Object(ObjectBlock {
583584
object: tool_declarations,
584-
}),
585+
})),
585586
);
586587
metadata.defs(defs);
587588
}
588589

589-
let pdl: PdlBlock = PdlBlock::Text(TextBlock {
590+
let pdl: PdlBlock = Advanced(Text(TextBlock {
590591
metadata: Some(metadata.build().unwrap()),
591592
role: None,
592593
parser: None,
593594
text: body,
594-
});
595+
}));
595596

596597
Ok(pdl)
597598
}

pdl-live-react/src-tauri/src/pdl/ast.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ impl SequencingBlock for LastOfBlock {
183183
&self.parser
184184
}
185185
fn to_block(&self) -> PdlBlock {
186-
PdlBlock::LastOf(self.clone())
186+
PdlBlock::Advanced(Block::LastOf(self.clone()))
187187
}
188188
fn result_for(&self, output_results: Vec<PdlResult>) -> PdlResult {
189189
match output_results.last() {
@@ -243,7 +243,7 @@ impl SequencingBlock for TextBlock {
243243
&self.parser
244244
}
245245
fn to_block(&self) -> PdlBlock {
246-
PdlBlock::Text(self.clone())
246+
PdlBlock::Advanced(Block::Text(self.clone()))
247247
}
248248
fn result_for(&self, output_results: Vec<PdlResult>) -> PdlResult {
249249
PdlResult::String(
@@ -596,8 +596,15 @@ pub enum PdlBlock {
596596
Number(Number),
597597
String(String),
598598
Function(FunctionBlock),
599+
Advanced(Block),
600+
// must be last to prevent serde from aggressively matching on it, since other block types also (may) have a `defs`
601+
Empty(EmptyBlock),
602+
}
599603

600-
// the rest have Metadata; TODO refactor to make this more explicit
604+
/// A PDL block that has structure and metadata
605+
#[derive(Serialize, Deserialize, Debug, Clone)]
606+
#[serde(untagged)]
607+
pub enum Block {
601608
If(IfBlock),
602609
Import(ImportBlock),
603610
Include(IncludeBlock),
@@ -612,9 +619,6 @@ pub enum PdlBlock {
612619
Model(ModelBlock),
613620
LastOf(LastOfBlock),
614621
Text(TextBlock),
615-
616-
// must be last to prevent serde from aggressively matching on it, since other block types also (may) have a `defs`
617-
Empty(EmptyBlock),
618622
}
619623

620624
impl From<bool> for PdlBlock {

pdl-live-react/src-tauri/src/pdl/extract.rs

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::pdl::ast::{Metadata, PdlBlock};
1+
use crate::pdl::ast::{Block::*, Metadata, PdlBlock, PdlBlock::Advanced};
22

33
/// Extract models referenced by the programs
44
pub fn extract_models(program: &PdlBlock) -> Vec<String> {
@@ -20,18 +20,26 @@ pub fn extract_values(program: &PdlBlock, field: &str) -> Vec<String> {
2020
/// Take one Yaml fragment and produce a vector of the string-valued entries of the given field
2121
fn extract_values_iter(program: &PdlBlock, field: &str, values: &mut Vec<String>) {
2222
match program {
23-
PdlBlock::Model(b) => values.push(b.model.clone()),
24-
PdlBlock::Repeat(b) => {
23+
PdlBlock::Empty(b) => {
24+
b.defs
25+
.values()
26+
.for_each(|p| extract_values_iter(p, field, values));
27+
}
28+
PdlBlock::Function(b) => {
29+
extract_values_iter(&b.return_, field, values);
30+
}
31+
Advanced(Model(b)) => values.push(b.model.clone()),
32+
Advanced(Repeat(b)) => {
2533
extract_values_iter(&b.repeat, field, values);
2634
}
27-
PdlBlock::Message(b) => {
35+
Advanced(Message(b)) => {
2836
extract_values_iter(&b.content, field, values);
2937
}
30-
PdlBlock::Array(b) => b
38+
Advanced(Array(b)) => b
3139
.array
3240
.iter()
3341
.for_each(|p| extract_values_iter(p, field, values)),
34-
PdlBlock::Text(b) => {
42+
Advanced(Text(b)) => {
3543
b.text
3644
.iter()
3745
.for_each(|p| extract_values_iter(p, field, values));
@@ -43,7 +51,7 @@ fn extract_values_iter(program: &PdlBlock, field: &str, values: &mut Vec<String>
4351
.for_each(|p| extract_values_iter(p, field, values));
4452
}
4553
}
46-
PdlBlock::LastOf(b) => {
54+
Advanced(LastOf(b)) => {
4755
b.last_of
4856
.iter()
4957
.for_each(|p| extract_values_iter(p, field, values));
@@ -55,7 +63,7 @@ fn extract_values_iter(program: &PdlBlock, field: &str, values: &mut Vec<String>
5563
.for_each(|p| extract_values_iter(p, field, values));
5664
}
5765
}
58-
PdlBlock::If(b) => {
66+
Advanced(If(b)) => {
5967
extract_values_iter(&b.then, field, values);
6068
if let Some(else_) = &b.else_ {
6169
extract_values_iter(else_, field, values);
@@ -68,20 +76,11 @@ fn extract_values_iter(program: &PdlBlock, field: &str, values: &mut Vec<String>
6876
.for_each(|p| extract_values_iter(p, field, values));
6977
}
7078
}
71-
PdlBlock::Empty(b) => {
72-
b.defs
73-
.values()
74-
.for_each(|p| extract_values_iter(p, field, values));
75-
}
76-
PdlBlock::Object(b) => b
79+
Advanced(Object(b)) => b
7780
.object
7881
.values()
7982
.for_each(|p| extract_values_iter(p, field, values)),
8083

81-
PdlBlock::Function(b) => {
82-
extract_values_iter(&b.return_, field, values);
83-
}
84-
8584
_ => {}
8685
}
8786
}

0 commit comments

Comments
 (0)