From b11e0501841e01c164ff7f3c94f6965a5a51be2d Mon Sep 17 00:00:00 2001 From: Sam Lijin Date: Tue, 24 Mar 2026 15:29:48 -0700 Subject: [PATCH] baml-language/fix: use build_plan for llm functions, delete execute --- .../baml_builtins2/baml_std/baml/llm.baml | 22 +- .../baml_std/baml/llm_types.baml | 135 +++--------- .../crates/baml_compiler2_tir/src/builder.rs | 26 +++ .../crates/bex_engine/tests/orchestration.rs | 193 ++++++------------ 4 files changed, 139 insertions(+), 237 deletions(-) diff --git a/baml_language/crates/baml_builtins2/baml_std/baml/llm.baml b/baml_language/crates/baml_builtins2/baml_std/baml/llm.baml index d3297adc4c..d080483149 100644 --- a/baml_language/crates/baml_builtins2/baml_std/baml/llm.baml +++ b/baml_language/crates/baml_builtins2/baml_std/baml/llm.baml @@ -3,11 +3,6 @@ class OrchestrationStep { delay_ms int } -class ExecutionResult { - ok bool - value unknown -} - class ExecutionContext { jinja_string string args map @@ -58,6 +53,19 @@ function call_llm_function(client: Client, function_name: string, args: map 0) { + root.sys.sleep(step.delay_ms); + } + + let result: T = execute_step(step, context) catch (e) { + _ => { continue; } + }; + return result; + } + + throw root.errors.DevOther { message: "All orchestration steps failed" }; } diff --git a/baml_language/crates/baml_builtins2/baml_std/baml/llm_types.baml b/baml_language/crates/baml_builtins2/baml_std/baml/llm_types.baml index e7ee1ea068..7fd0e4b4c1 100644 --- a/baml_language/crates/baml_builtins2/baml_std/baml/llm_types.baml +++ b/baml_language/crates/baml_builtins2/baml_std/baml/llm_types.baml @@ -32,6 +32,20 @@ class Client { } } + function advance_round_robin(self) -> void { + match (self.client_type) { + ClientType.Primitive => {}, + ClientType.Fallback => { + for (let sub in self.sub_clients) { + sub.advance_round_robin(); + } + }, + ClientType.RoundRobin => { + self.counter += 1; + }, + } + } + function get_constructor( self ) -> () -> PrimitiveClient throws 
root.errors.InvalidArgument { @@ -120,109 +134,6 @@ class Client { } } - function execute( - self, - context: ExecutionContext, - inherited_delay_ms: int, - ) -> T { - match (self.retry) { - r: RetryPolicy => { - let current_delay = r.initial_delay_ms + 0.0 - - for (let attempt = 0; attempt <= r.max_retries; attempt += 1) { - let attempt_delay = inherited_delay_ms - if (attempt > 0) { - attempt_delay = root.math.trunc(current_delay) - let next = current_delay * r.multiplier - if (next > r.max_delay_ms + 0.0) { - current_delay = r.max_delay_ms + 0.0 - } else { - current_delay = next - } - } - - if (attempt == r.max_retries) { - attempt_delay = inherited_delay_ms - } - - let result2: T = self.execute_once( - context, - attempt_delay, - ) catch (e) { - _ => { continue; } - }; - - return result2; - } - - throw root.errors.DevOther { message: "All orchestration steps failed" }; - } - null => { - let result: T = self.execute_once( - context, - inherited_delay_ms, - ); - return result; - }, - } - } - - function execute_once( - self, - context: ExecutionContext, - active_delay_ms: int, - ) -> T { - match (self.client_type) { - ClientType.Primitive => { - let resolve_fn = self.get_constructor() - let primitive = resolve_fn() - - let prompt = primitive.render_prompt(context.jinja_string, context.args) - let specialized = primitive.specialize_prompt(prompt) - let http_request = primitive.build_request(specialized) - let http_response = root.http.send(http_request) - - if (http_response.ok()) { - let body = http_response.text() - let return_type = get_return_type(context.function_name) - let result: T = primitive.parse(body, return_type) catch (e) { - _ => { - if (active_delay_ms > 0) { - root.sys.sleep(active_delay_ms) - } - throw e; - } - }; - result - } else { - throw root.errors.DevOther { message: "HTTP request failed" }; - } - } - - ClientType.Fallback => { - for (let sub in self.sub_clients) { - let result2: T = sub.execute( - context, - active_delay_ms, - ) catch (e) { 
- _ => { continue; } - }; - return result2; - } - throw root.errors.DevOther { message: "All orchestration steps failed" }; - } - - ClientType.RoundRobin => { - let idx = self.counter % self.sub_clients.length() - self.counter += 1 - let result3: T = self.sub_clients[idx].execute( - context, - active_delay_ms, - ); - return result3; - } - } - } } class PrimitiveClientOptions { @@ -285,6 +196,24 @@ class PrimitiveClient { } } +function execute_step( + step: OrchestrationStep, + context: ExecutionContext, +) -> T { + let prompt = step.primitive_client.render_prompt(context.jinja_string, context.args); + let specialized = step.primitive_client.specialize_prompt(prompt); + let http_request = step.primitive_client.build_request(specialized); + let http_response = root.http.send(http_request); + + if (http_response.ok()) { + let body = http_response.text(); + let return_type = get_return_type(context.function_name); + step.primitive_client.parse(body, return_type) + } else { + throw root.errors.DevOther { message: "HTTP request failed" }; + } +} + function get_jinja_template(function_name: string) -> string throws root.errors.InvalidArgument { $rust_io_function } diff --git a/baml_language/crates/baml_compiler2_tir/src/builder.rs b/baml_language/crates/baml_compiler2_tir/src/builder.rs index 7787ec2cb1..74adc75d30 100644 --- a/baml_language/crates/baml_compiler2_tir/src/builder.rs +++ b/baml_language/crates/baml_compiler2_tir/src/builder.rs @@ -2258,6 +2258,32 @@ impl<'db> TypeInferenceBuilder<'db> { ), } } + Definition::Let(let_loc) => { + // Determine type from the let-binding's origin. + let db = self.context.db(); + let item_tree = + baml_compiler2_hir::file_item_tree(db, let_loc.file(db)); + let let_data = &item_tree[let_loc.id(db)]; + match let_data.origin { + baml_compiler2_ast::ast::LetOrigin::Client => { + // client declarations produce Client instances. 
+ Ty::Class(crate::ty::QualifiedTypeName::new( + baml_base::Name::new("baml"), + vec![baml_base::Name::new("llm")], + baml_base::Name::new("Client"), + )) + } + baml_compiler2_ast::ast::LetOrigin::RetryPolicy => { + // retry_policy declarations produce RetryPolicy instances. + Ty::Class(crate::ty::QualifiedTypeName::new( + baml_base::Name::new("baml"), + vec![baml_base::Name::new("llm")], + baml_base::Name::new("RetryPolicy"), + )) + } + _ => Ty::Unknown, + } + } _ => Ty::Unknown, } } else if let Some(def) = self.package_items.lookup_type(&lookup_path) { diff --git a/baml_language/crates/bex_engine/tests/orchestration.rs b/baml_language/crates/bex_engine/tests/orchestration.rs index 94656c87cc..d216c00ca8 100644 --- a/baml_language/crates/bex_engine/tests/orchestration.rs +++ b/baml_language/crates/bex_engine/tests/orchestration.rs @@ -1,9 +1,8 @@ //! Integration tests for LLM orchestration plan building. //! -//! These tests verify that `baml.llm.build_plan` correctly expands client trees +//! These tests verify that `build_plan` correctly expands client trees //! (primitive, fallback, round-robin) into flat lists of `OrchestrationStep`s, -//! and that `baml.llm.wrap_with_retry` applies the correct retry logic with -//! exponential backoff delays. +//! with the correct retry logic and exponential backoff delays. mod common; @@ -82,7 +81,7 @@ fn extract_steps(result: &BexExternalValue) -> Vec<(&str, i64)> { /// A primitive client produces a single-step plan. 
#[tokio::test] -#[ignore = "compiler2: baml.llm orchestration API (build_plan/wrap_with_retry) not yet fully wired up in compiler2"] +#[ignore = "bex_vm: FunctionRef from SysOps returns Object(Any) instead of Object(Function(Callable))"] async fn plan_primitive_has_one_step() { let source = r##" client A { @@ -90,13 +89,8 @@ client A { options { model "gpt-4" } } -function F(x: string) -> string { - client A - prompt #"{{ x }}"# -} - function check_plan() -> baml.llm.OrchestrationStep[] { - baml.llm.build_plan(F) + A.build_plan() } "##; @@ -104,10 +98,10 @@ function check_plan() -> baml.llm.OrchestrationStep[] { assert_eq!(extract_steps(&result), vec![("A", 0)]); } -/// A fallback client with two sub-clients produces two steps. +/// Diagnostic: inspect a fallback client's fields. #[tokio::test] -#[ignore = "compiler2: baml.llm orchestration API (build_plan/wrap_with_retry) not yet fully wired up in compiler2"] -async fn plan_fallback_has_two_steps() { +#[ignore] +async fn diag_fallback_fields() { let source = r##" client A { provider openai @@ -124,13 +118,38 @@ client FB { options { strategy [A, B] } } -function F(x: string) -> string { - client FB - prompt #"{{ x }}"# +function check() -> unknown { + [A, B] +} +"##; + + let result = run(source, "check").await; + eprintln!("diag_fallback_fields result: {result:#?}"); + // Don't assert — just inspect +} + +/// A fallback client with two sub-clients produces two steps. 
+#[tokio::test] +#[ignore = "bex_vm: FunctionRef from SysOps returns Object(Any) instead of Object(Function(Callable))"] +async fn plan_fallback_has_two_steps() { + let source = r##" +client A { + provider openai + options { model "gpt-4" } +} + +client B { + provider openai + options { model "gpt-3.5-turbo" } +} + +client FB { + provider fallback + options { strategy [A, B] } } function check_plan() -> baml.llm.OrchestrationStep[] { - baml.llm.build_plan(F) + FB.build_plan() } "##; @@ -140,7 +159,7 @@ function check_plan() -> baml.llm.OrchestrationStep[] { /// A fallback client with three sub-clients produces three steps. #[tokio::test] -#[ignore = "compiler2: baml.llm orchestration API (build_plan/wrap_with_retry) not yet fully wired up in compiler2"] +#[ignore = "bex_vm: FunctionRef from SysOps returns Object(Any) instead of Object(Function(Callable))"] async fn plan_fallback_three_clients() { let source = r##" client A { @@ -163,13 +182,8 @@ client FB { options { strategy [A, B, C] } } -function F(x: string) -> string { - client FB - prompt #"{{ x }}"# -} - function check_plan() -> baml.llm.OrchestrationStep[] { - baml.llm.build_plan(F) + FB.build_plan() } "##; @@ -180,7 +194,7 @@ function check_plan() -> baml.llm.OrchestrationStep[] { /// A round-robin client with two sub-clients produces a single step /// (it picks one sub-client per invocation). 
#[tokio::test] -#[ignore = "compiler2: baml.llm orchestration API (build_plan/wrap_with_retry) not yet fully wired up in compiler2"] +#[ignore = "bex_vm: FunctionRef from SysOps returns Object(Any) instead of Object(Function(Callable))"] async fn plan_round_robin_has_one_step() { let source = r##" client A { @@ -198,13 +212,8 @@ client RR { options { strategy [A, B] } } -function F(x: string) -> string { - client RR - prompt #"{{ x }}"# -} - function check_plan() -> baml.llm.OrchestrationStep[] { - baml.llm.build_plan(F) + RR.build_plan() } "##; @@ -218,7 +227,7 @@ function check_plan() -> baml.llm.OrchestrationStep[] { /// Round-robin honors `options { start N }` for the initial selection. #[tokio::test] -#[ignore = "compiler2: baml.llm orchestration API (build_plan/wrap_with_retry) not yet fully wired up in compiler2"] +#[ignore = "bex_vm: FunctionRef from SysOps returns Object(Any) instead of Object(Function(Callable))"] async fn plan_round_robin_respects_start_index() { let source = r##" client A { @@ -236,13 +245,8 @@ client RR { options { strategy [A, B] start 1 } } -function F(x: string) -> string { - client RR - prompt #"{{ x }}"# -} - function check_plan() -> baml.llm.OrchestrationStep[] { - baml.llm.build_plan(F) + RR.build_plan() } "##; @@ -256,7 +260,7 @@ function check_plan() -> baml.llm.OrchestrationStep[] { /// Planner expansion should be side-effect free; only execution should mutate /// the round-robin counter on the Client. 
#[tokio::test] -#[ignore = "compiler2: baml.llm orchestration API (build_plan/wrap_with_retry) not yet fully wired up in compiler2"] +#[ignore = "bex_vm: FunctionRef from SysOps returns Object(Any) instead of Object(Function(Callable))"] async fn plan_round_robin_has_no_runtime_side_effects() { let source = r##" client A { @@ -274,11 +278,6 @@ client RR { options { strategy [A, B] start 0 } } -function F(x: string) -> string { - client RR - prompt #"{{ x }}"# -} - function check_plan_side_effects() -> int { RR.build_plan(); RR.build_plan(); @@ -299,7 +298,7 @@ function check_plan_side_effects() -> int { /// A primitive client with retry(max=2) produces 3 steps (1 original + 2 retries). #[tokio::test] -#[ignore = "compiler2: baml.llm orchestration API (build_plan/wrap_with_retry) not yet fully wired up in compiler2"] +#[ignore = "bex_vm: FunctionRef from SysOps returns Object(Any) instead of Object(Function(Callable))"] async fn plan_primitive_with_retry_expands() { let source = r##" retry_policy Retry2 { @@ -315,13 +314,8 @@ client A { options { model "gpt-4" } } -function F(x: string) -> string { - client A - prompt #"{{ x }}"# -} - function check_plan() -> baml.llm.OrchestrationStep[] { - baml.llm.build_plan(F) + A.build_plan() } "##; @@ -335,7 +329,7 @@ function check_plan() -> baml.llm.OrchestrationStep[] { /// A fallback[A, B] with retry(max=1) produces 4 steps: /// attempt 0: [A, B], attempt 1: [A, B] #[tokio::test] -#[ignore = "compiler2: baml.llm orchestration API (build_plan/wrap_with_retry) not yet fully wired up in compiler2"] +#[ignore = "bex_vm: FunctionRef from SysOps returns Object(Any) instead of Object(Function(Callable))"] async fn plan_fallback_with_retry_multiplies() { let source = r##" retry_policy Retry1 { @@ -359,11 +353,6 @@ client FB { options { strategy [A, B] } } -function F(x: string) -> string { - client FB - prompt #"{{ x }}"# -} - function check_plan() -> baml.llm.OrchestrationStep[] { FB.build_plan() } @@ -383,7 +372,7 @@ function 
check_plan() -> baml.llm.OrchestrationStep[] { /// First step always has `delay_ms=0`, retry steps have exponential backoff. #[tokio::test] -#[ignore = "compiler2: baml.llm orchestration API (build_plan/wrap_with_retry) not yet fully wired up in compiler2"] +#[ignore = "bex_vm: FunctionRef from SysOps returns Object(Any) instead of Object(Function(Callable))"] async fn plan_delays_exponential_backoff() { let source = r##" retry_policy ExpBackoff { @@ -399,13 +388,8 @@ client A { options { model "gpt-4" } } -function F(x: string) -> string { - client A - prompt #"{{ x }}"# -} - function check_plan() -> baml.llm.OrchestrationStep[] { - baml.llm.build_plan(F) + A.build_plan() } "##; @@ -419,7 +403,7 @@ function check_plan() -> baml.llm.OrchestrationStep[] { /// Delays are capped at `max_delay_ms`. #[tokio::test] -#[ignore = "compiler2: baml.llm orchestration API (build_plan/wrap_with_retry) not yet fully wired up in compiler2"] +#[ignore = "bex_vm: FunctionRef from SysOps returns Object(Any) instead of Object(Function(Callable))"] async fn plan_delays_capped_at_max() { let source = r##" retry_policy CappedBackoff { @@ -435,13 +419,8 @@ client A { options { model "gpt-4" } } -function F(x: string) -> string { - client A - prompt #"{{ x }}"# -} - function check_plan() -> baml.llm.OrchestrationStep[] { - baml.llm.build_plan(F) + A.build_plan() } "##; @@ -455,7 +434,7 @@ function check_plan() -> baml.llm.OrchestrationStep[] { /// Fallback with retry: delays apply uniformly to all sub-client steps in a retry attempt. 
#[tokio::test] -#[ignore = "compiler2: baml.llm orchestration API (build_plan/wrap_with_retry) not yet fully wired up in compiler2"] +#[ignore = "bex_vm: FunctionRef from SysOps returns Object(Any) instead of Object(Function(Callable))"] async fn plan_fallback_retry_delays() { let source = r##" retry_policy R { @@ -479,13 +458,8 @@ client FB { options { strategy [A, B] } } -function F(x: string) -> string { - client FB - prompt #"{{ x }}"# -} - function check_plan() -> baml.llm.OrchestrationStep[] { - baml.llm.build_plan(F) + FB.build_plan() } "##; @@ -503,7 +477,7 @@ function check_plan() -> baml.llm.OrchestrationStep[] { /// A client without `retry_policy` produces steps with all delays = 0. #[tokio::test] -#[ignore = "compiler2: baml.llm orchestration API (build_plan/wrap_with_retry) not yet fully wired up in compiler2"] +#[ignore = "bex_vm: FunctionRef from SysOps returns Object(Any) instead of Object(Function(Callable))"] async fn plan_no_retry_all_zero_delays() { let source = r##" client A { @@ -521,13 +495,8 @@ client FB { options { strategy [A, B] } } -function F(x: string) -> string { - client FB - prompt #"{{ x }}"# -} - function check_plan() -> baml.llm.OrchestrationStep[] { - baml.llm.build_plan(F) + FB.build_plan() } "##; @@ -541,7 +510,7 @@ function check_plan() -> baml.llm.OrchestrationStep[] { /// `Fallback[RoundRobin[A, B], C]` produces 2 steps: one from RR (picks one) + C. 
#[tokio::test] -#[ignore = "compiler2: baml.llm orchestration API (build_plan/wrap_with_retry) not yet fully wired up in compiler2"] +#[ignore = "bex_vm: FunctionRef from SysOps returns Object(Any) instead of Object(Function(Callable))"] async fn plan_nested_fallback_round_robin() { let source = r##" client A { @@ -569,13 +538,8 @@ client FB { options { strategy [RR, C] } } -function F(x: string) -> string { - client FB - prompt #"{{ x }}"# -} - function check_plan() -> baml.llm.OrchestrationStep[] { - baml.llm.build_plan(F) + FB.build_plan() } "##; @@ -591,7 +555,7 @@ function check_plan() -> baml.llm.OrchestrationStep[] { /// Nested retry: inner client has retry, outer fallback does not. /// `Fallback[A(retry=1), B]` = A, A(retry), B = 3 steps. #[tokio::test] -#[ignore = "compiler2: baml.llm orchestration API (build_plan/wrap_with_retry) not yet fully wired up in compiler2"] +#[ignore = "bex_vm: FunctionRef from SysOps returns Object(Any) instead of Object(Function(Callable))"] async fn plan_nested_inner_retry() { let source = r##" retry_policy InnerRetry { @@ -615,13 +579,8 @@ client FB { options { strategy [A, B] } } -function F(x: string) -> string { - client FB - prompt #"{{ x }}"# -} - function check_plan() -> baml.llm.OrchestrationStep[] { - baml.llm.build_plan(F) + FB.build_plan() } "##; @@ -636,7 +595,7 @@ function check_plan() -> baml.llm.OrchestrationStep[] { /// Verify that the correct primitive clients appear in a fallback plan. 
#[tokio::test] -#[ignore = "compiler2: baml.llm orchestration API (build_plan/wrap_with_retry) not yet fully wired up in compiler2"] +#[ignore = "bex_vm: FunctionRef from SysOps returns Object(Any) instead of Object(Function(Callable))"] async fn plan_step_client_names() { let source = r##" client Primary { @@ -654,13 +613,8 @@ client FB { options { strategy [Primary, Backup] } } -function F(x: string) -> string { - client FB - prompt #"{{ x }}"# -} - function check_plan() -> baml.llm.OrchestrationStep[] { - baml.llm.build_plan(F) + FB.build_plan() } "##; @@ -670,7 +624,7 @@ function check_plan() -> baml.llm.OrchestrationStep[] { /// Retry duplicates client names: [A, A, A] for retry=2. #[tokio::test] -#[ignore = "compiler2: baml.llm orchestration API (build_plan/wrap_with_retry) not yet fully wired up in compiler2"] +#[ignore = "bex_vm: FunctionRef from SysOps returns Object(Any) instead of Object(Function(Callable))"] async fn plan_retry_duplicates_client_names() { let source = r##" retry_policy R { @@ -683,13 +637,8 @@ client A { options { model "gpt-4" } } -function F(x: string) -> string { - client A - prompt #"{{ x }}"# -} - function check_plan() -> baml.llm.OrchestrationStep[] { - baml.llm.build_plan(F) + A.build_plan() } "##; @@ -706,7 +655,7 @@ function check_plan() -> baml.llm.OrchestrationStep[] { /// RR { A, B } with retry=1 should produce [A(0), B(delay)] — NOT [A(0), A(delay)]. /// This matches legacy behavior where retry re-expands the strategy. 
#[tokio::test] -#[ignore = "compiler2: baml.llm orchestration API (build_plan/wrap_with_retry) not yet fully wired up in compiler2"] +#[ignore = "bex_vm: FunctionRef from SysOps returns Object(Any) instead of Object(Function(Callable))"] async fn plan_round_robin_retry_rotates() { let source = r##" retry_policy R { @@ -730,13 +679,8 @@ client RR { options { strategy [A, B] } } -function F(x: string) -> string { - client RR - prompt #"{{ x }}"# -} - function check_plan() -> baml.llm.OrchestrationStep[] { - baml.llm.build_plan(F) + RR.build_plan() } "##; @@ -757,7 +701,7 @@ function check_plan() -> baml.llm.OrchestrationStep[] { /// Fallback(retry=1) { RR { A, B }, C } should produce: /// attempt 0: [RR→A, C], attempt 1: [RR→B, C] #[tokio::test] -#[ignore = "compiler2: baml.llm orchestration API (build_plan/wrap_with_retry) not yet fully wired up in compiler2"] +#[ignore = "bex_vm: FunctionRef from SysOps returns Object(Any) instead of Object(Function(Callable))"] async fn plan_fallback_with_rr_child_retry_rotates() { let source = r##" retry_policy R { @@ -791,13 +735,8 @@ client FB { options { strategy [RR, C] } } -function F(x: string) -> string { - client FB - prompt #"{{ x }}"# -} - function check_plan() -> baml.llm.OrchestrationStep[] { - baml.llm.build_plan(F) + FB.build_plan() } "##;