Skip to content

Commit b4ab2bf

Browse files
committed
fix: add kind tags to rust ast blocks
Signed-off-by: Nick Mitchell <[email protected]>
1 parent 828cefa commit b4ab2bf

File tree

5 files changed

+208
-154
lines changed

5 files changed

+208
-154
lines changed

pdl-live-react/src-tauri/src/compile/beeai.rs

Lines changed: 74 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ use tempfile::Builder;
1313

1414
use crate::pdl::ast::{
1515
ArrayBlock, CallBlock, FunctionBlock, ListOrString, MessageBlock, ModelBlock, ObjectBlock,
16-
PdlBaseType, PdlBlock, PdlOptionalType, PdlParser, PdlType, PythonCodeBlock, RepeatBlock, Role,
17-
TextBlock,
16+
PdlAdvancedBlock, PdlBaseType, PdlBlock, PdlOptionalType, PdlParser, PdlType, PythonCodeBlock,
17+
RepeatBlock, Role, TextBlock,
1818
};
1919
use crate::pdl::pip::pip_install_if_needed;
2020
use crate::pdl::requirements::BEEAI_FRAMEWORK;
@@ -190,35 +190,39 @@ fn with_tools(
190190
}
191191

192192
fn call_tools(model: &String, parameters: &HashMap<String, Value>) -> PdlBlock {
193-
let repeat = PdlBlock::Text(TextBlock {
193+
let repeat = PdlBlock::Advanced(PdlAdvancedBlock::Text(TextBlock {
194194
def: None,
195195
defs: None,
196196
role: None,
197197
parser: None,
198198
description: Some("Calling tool ${ tool.function.name }".to_string()),
199-
text: vec![PdlBlock::Model(
199+
text: vec![PdlBlock::Advanced(PdlAdvancedBlock::Model(
200200
ModelBlock::new(model.as_str())
201201
.parameters(&strip_nulls(parameters))
202-
.input(PdlBlock::Array(ArrayBlock {
203-
array: vec![PdlBlock::Message(MessageBlock {
204-
role: Role::Tool,
205-
description: None,
206-
name: Some("${ tool.function.name }".to_string()),
207-
tool_call_id: Some("${ tool.id }".to_string()),
208-
content: Box::new(PdlBlock::Call(CallBlock {
209-
defs: json_loads(
210-
&"args",
211-
&"pdl__args",
212-
&"${ tool.function.arguments }",
213-
),
214-
call: "${ pdl__tools[tool.function.name] }".to_string(), // look up tool in tool_declarations def (see below)
215-
args: Some("${ args }".into()), // invoke with arguments as specified by the model
216-
})),
217-
})],
218-
}))
202+
.input(PdlBlock::Advanced(PdlAdvancedBlock::Array(ArrayBlock {
203+
array: vec![PdlBlock::Advanced(PdlAdvancedBlock::Message(
204+
MessageBlock {
205+
role: Role::Tool,
206+
description: None,
207+
name: Some("${ tool.function.name }".to_string()),
208+
tool_call_id: Some("${ tool.id }".to_string()),
209+
content: Box::new(PdlBlock::Advanced(PdlAdvancedBlock::Call(
210+
CallBlock {
211+
defs: json_loads(
212+
&"args",
213+
&"pdl__args",
214+
&"${ tool.function.arguments }",
215+
),
216+
call: "${ pdl__tools[tool.function.name] }".to_string(), // look up tool in tool_declarations def (see below)
217+
args: Some("${ args }".into()), // invoke with arguments as specified by the model
218+
},
219+
))),
220+
},
221+
))],
222+
})))
219223
.build(),
220-
)],
221-
});
224+
))],
225+
}));
222226

223227
let mut for_ = HashMap::new();
224228
for_.insert(
@@ -227,10 +231,10 @@ fn call_tools(model: &String, parameters: &HashMap<String, Value>) -> PdlBlock {
227231
);
228232

229233
// response.choices[0].message.tool_calls
230-
PdlBlock::Repeat(RepeatBlock {
234+
PdlBlock::Advanced(PdlAdvancedBlock::Repeat(RepeatBlock {
231235
for_: for_,
232236
repeat: Box::new(repeat),
233-
})
237+
}))
234238
}
235239

236240
fn json_loads(
@@ -241,15 +245,15 @@ fn json_loads(
241245
let mut m = indexmap::IndexMap::new();
242246
m.insert(
243247
outer_name.to_owned(),
244-
PdlBlock::Text(
248+
PdlBlock::Advanced(PdlAdvancedBlock::Text(
245249
TextBlock::new(vec![PdlBlock::String(format!(
246250
"{{\"{}\": {}}}",
247251
inner_name, value
248252
))])
249253
.description(format!("Parsing json for {}={}", inner_name, value))
250254
.parser(PdlParser::Json)
251255
.build(),
252-
),
256+
)),
253257
);
254258
Some(m)
255259
}
@@ -396,13 +400,14 @@ pub fn compile(source_file_path: &str, debug: bool) -> Result<PdlBlock, Box<dyn
396400
.map(|((import_from, import_fn), tool_name, schema)| {
397401
(
398402
tool_name.clone(),
399-
PdlBlock::Function(FunctionBlock {
403+
PdlBlock::Advanced(PdlAdvancedBlock::Function(FunctionBlock {
400404
function: schema,
401-
return_: Box::new(PdlBlock::PythonCode(PythonCodeBlock {
402-
// tool function definition
403-
lang: "python".to_string(),
404-
code: format!(
405-
"
405+
return_: Box::new(PdlBlock::Advanced(PdlAdvancedBlock::PythonCode(
406+
PythonCodeBlock {
407+
// tool function definition
408+
lang: "python".to_string(),
409+
code: format!(
410+
"
406411
from {} import {}
407412
import asyncio
408413
async def invoke():
@@ -414,25 +419,26 @@ async def invoke():
414419
{}
415420
asyncio.run(invoke())
416421
",
417-
import_from,
418-
import_fn,
419-
if debug {
420-
format!("print('Invoking tool {}')", tool_name)
421-
} else {
422-
"".to_string()
423-
},
424-
import_fn,
425-
if debug {
426-
format!(
427-
"print(f'Response from tool {}: {{result}}')",
428-
tool_name
429-
)
430-
} else {
431-
"".to_string()
432-
}
433-
),
434-
})),
435-
}),
422+
import_from,
423+
import_fn,
424+
if debug {
425+
format!("print('Invoking tool {}')", tool_name)
426+
} else {
427+
"".to_string()
428+
},
429+
import_fn,
430+
if debug {
431+
format!(
432+
"print(f'Response from tool {}: {{result}}')",
433+
tool_name
434+
)
435+
} else {
436+
"".to_string()
437+
}
438+
),
439+
},
440+
))),
441+
})),
436442
)
437443
})
438444
})
@@ -460,14 +466,14 @@ asyncio.run(invoke())
460466
let model = format!("{}/{}", provider, model);
461467

462468
if let Some(instructions) = instructions {
463-
model_call.push(PdlBlock::Text(TextBlock {
469+
model_call.push(PdlBlock::Advanced(PdlAdvancedBlock::Text(TextBlock {
464470
role: Some(Role::System),
465471
text: vec![PdlBlock::String(instructions)],
466472
def: None,
467473
defs: None,
468474
parser: None,
469475
description: Some("Model instructions".into()),
470-
}));
476+
})));
471477
}
472478

473479
let model_response = if let Some(tools) = &tools {
@@ -479,7 +485,7 @@ asyncio.run(invoke())
479485
None
480486
};
481487

482-
model_call.push(PdlBlock::Model(ModelBlock {
488+
model_call.push(PdlBlock::Advanced(PdlAdvancedBlock::Model(ModelBlock {
483489
input: None,
484490
description: Some(description),
485491
def: None,
@@ -488,7 +494,7 @@ asyncio.run(invoke())
488494
pdl_result: None,
489495
pdl_usage: None,
490496
parameters: Some(with_tools(&tools, &parameters.state.dict)),
491-
}));
497+
})));
492498

493499
if let Some(tools) = tools {
494500
if tools.len() > 0 {
@@ -500,29 +506,28 @@ asyncio.run(invoke())
500506
let mut defs = indexmap::IndexMap::new();
501507
defs.insert(
502508
closure_name.clone(),
503-
PdlBlock::Function(FunctionBlock {
509+
PdlBlock::Advanced(PdlAdvancedBlock::Function(FunctionBlock {
504510
function: HashMap::new(),
505-
return_: Box::new(PdlBlock::Text(TextBlock {
511+
return_: Box::new(PdlBlock::Advanced(PdlAdvancedBlock::Text(TextBlock {
506512
def: None,
507513
defs: None,
508514
role: None,
509515
parser: None,
510516
description: Some(format!("Model call {}", &model)),
511517
text: model_call,
512-
})),
513-
}),
518+
}))),
519+
})),
514520
);
515-
PdlBlock::Text(TextBlock {
521+
PdlBlock::Advanced(PdlAdvancedBlock::Text(TextBlock {
516522
def: None,
517523
defs: Some(defs),
518524
role: None,
519525
parser: None,
520526
description: Some("Model call wrapper".to_string()),
521-
text: vec![PdlBlock::Call(CallBlock::new(format!(
522-
"${{ {} }}",
523-
closure_name
527+
text: vec![PdlBlock::Advanced(PdlAdvancedBlock::Call(CallBlock::new(
528+
format!("${{ {} }}", closure_name),
524529
)))],
525-
})
530+
}))
526531
},
527532
)
528533
.collect::<Vec<_>>();
@@ -531,25 +536,25 @@ asyncio.run(invoke())
531536
.flat_map(|(a, b)| [a, b])
532537
.collect::<Vec<_>>();
533538

534-
let pdl: PdlBlock = PdlBlock::Text(TextBlock {
539+
let pdl: PdlBlock = PdlBlock::Advanced(PdlAdvancedBlock::Text(TextBlock {
535540
def: None,
536541
defs: if tool_declarations.len() == 0 {
537542
None
538543
} else {
539544
let mut m = indexmap::IndexMap::new();
540545
m.insert(
541546
"pdl__tools".to_string(),
542-
PdlBlock::Object(ObjectBlock {
547+
PdlBlock::Advanced(PdlAdvancedBlock::Object(ObjectBlock {
543548
object: tool_declarations,
544-
}),
549+
})),
545550
);
546551
Some(m)
547552
},
548553
description: Some(bee.workflow.workflow.name),
549554
role: None,
550555
parser: None,
551556
text: body,
552-
});
557+
}));
553558

554559
Ok(pdl)
555560
}

0 commit comments

Comments
 (0)