-
Notifications
You must be signed in to change notification settings - Fork 401
baml-language/fix: use build_plan for llm functions, delete execute #3279
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: canary
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,11 +3,6 @@ class OrchestrationStep { | |
| delay_ms int | ||
| } | ||
|
|
||
| class ExecutionResult { | ||
| ok bool | ||
| value unknown | ||
| } | ||
|
|
||
| class ExecutionContext { | ||
| jinja_string string | ||
| args map<string, unknown> | ||
|
|
@@ -58,6 +53,19 @@ function call_llm_function<T>(client: Client, function_name: string, args: map<s | |
| function_name: function_name, | ||
| }; | ||
|
|
||
| let result: T = client.execute(context, 0); | ||
| result | ||
| let steps = client.build_plan(); | ||
| client.advance_round_robin(); | ||
|
|
||
| for (let step in steps) { | ||
| if (step.delay_ms > 0) { | ||
| root.sys.sleep(step.delay_ms); | ||
| } | ||
|
|
||
| let result: T = execute_step(step, context) catch (e) { | ||
| _ => { continue; } | ||
| }; | ||
| return result; | ||
| } | ||
|
|
||
| throw root.errors.DevOther { message: "All orchestration steps failed" }; | ||
|
Comment on lines
+64
to
+70
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Don't treat every
|
||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,6 +32,20 @@ class Client { | |
| } | ||
| } | ||
|
|
||
| function advance_round_robin(self) -> void { | ||
| match (self.client_type) { | ||
| ClientType.Primitive => {}, | ||
| ClientType.Fallback => { | ||
| for (let sub in self.sub_clients) { | ||
| sub.advance_round_robin(); | ||
| } | ||
| }, | ||
| ClientType.RoundRobin => { | ||
| self.counter += 1; | ||
| }, | ||
| } | ||
| } | ||
|
Comment on lines
+35
to
+47
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Advance the selected round-robin child too.
|
||
|
|
||
| function get_constructor( | ||
| self | ||
| ) -> () -> PrimitiveClient throws root.errors.InvalidArgument { | ||
|
|
@@ -120,109 +134,6 @@ class Client { | |
| } | ||
| } | ||
|
|
||
| function execute<T>( | ||
| self, | ||
| context: ExecutionContext, | ||
| inherited_delay_ms: int, | ||
| ) -> T { | ||
| match (self.retry) { | ||
| r: RetryPolicy => { | ||
| let current_delay = r.initial_delay_ms + 0.0 | ||
|
|
||
| for (let attempt = 0; attempt <= r.max_retries; attempt += 1) { | ||
| let attempt_delay = inherited_delay_ms | ||
| if (attempt > 0) { | ||
| attempt_delay = root.math.trunc(current_delay) | ||
| let next = current_delay * r.multiplier | ||
| if (next > r.max_delay_ms + 0.0) { | ||
| current_delay = r.max_delay_ms + 0.0 | ||
| } else { | ||
| current_delay = next | ||
| } | ||
| } | ||
|
|
||
| if (attempt == r.max_retries) { | ||
| attempt_delay = inherited_delay_ms | ||
| } | ||
|
|
||
| let result2: T = self.execute_once( | ||
| context, | ||
| attempt_delay, | ||
| ) catch (e) { | ||
| _ => { continue; } | ||
| }; | ||
|
|
||
| return result2; | ||
| } | ||
|
|
||
| throw root.errors.DevOther { message: "All orchestration steps failed" }; | ||
| } | ||
| null => { | ||
| let result: T = self.execute_once( | ||
| context, | ||
| inherited_delay_ms, | ||
| ); | ||
| return result; | ||
| }, | ||
| } | ||
| } | ||
|
|
||
| function execute_once<T>( | ||
| self, | ||
| context: ExecutionContext, | ||
| active_delay_ms: int, | ||
| ) -> T { | ||
| match (self.client_type) { | ||
| ClientType.Primitive => { | ||
| let resolve_fn = self.get_constructor() | ||
| let primitive = resolve_fn() | ||
|
|
||
| let prompt = primitive.render_prompt(context.jinja_string, context.args) | ||
| let specialized = primitive.specialize_prompt(prompt) | ||
| let http_request = primitive.build_request(specialized) | ||
| let http_response = root.http.send(http_request) | ||
|
|
||
| if (http_response.ok()) { | ||
| let body = http_response.text() | ||
| let return_type = get_return_type(context.function_name) | ||
| let result: T = primitive.parse(body, return_type) catch (e) { | ||
| _ => { | ||
| if (active_delay_ms > 0) { | ||
| root.sys.sleep(active_delay_ms) | ||
| } | ||
| throw e; | ||
| } | ||
| }; | ||
| result | ||
| } else { | ||
| throw root.errors.DevOther { message: "HTTP request failed" }; | ||
| } | ||
| } | ||
|
|
||
| ClientType.Fallback => { | ||
| for (let sub in self.sub_clients) { | ||
| let result2: T = sub.execute( | ||
| context, | ||
| active_delay_ms, | ||
| ) catch (e) { | ||
| _ => { continue; } | ||
| }; | ||
| return result2; | ||
| } | ||
| throw root.errors.DevOther { message: "All orchestration steps failed" }; | ||
| } | ||
|
|
||
| ClientType.RoundRobin => { | ||
| let idx = self.counter % self.sub_clients.length() | ||
| self.counter += 1 | ||
| let result3: T = self.sub_clients[idx].execute( | ||
| context, | ||
| active_delay_ms, | ||
| ); | ||
| return result3; | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| class PrimitiveClientOptions { | ||
|
|
@@ -285,6 +196,24 @@ class PrimitiveClient { | |
| } | ||
| } | ||
|
|
||
| function execute_step<T>( | ||
| step: OrchestrationStep, | ||
| context: ExecutionContext, | ||
| ) -> T { | ||
| let prompt = step.primitive_client.render_prompt(context.jinja_string, context.args); | ||
| let specialized = step.primitive_client.specialize_prompt(prompt); | ||
| let http_request = step.primitive_client.build_request(specialized); | ||
| let http_response = root.http.send(http_request); | ||
|
|
||
| if (http_response.ok()) { | ||
| let body = http_response.text(); | ||
| let return_type = get_return_type(context.function_name); | ||
| step.primitive_client.parse(body, return_type) | ||
| } else { | ||
| throw root.errors.DevOther { message: "HTTP request failed" }; | ||
| } | ||
| } | ||
|
|
||
| function get_jinja_template(function_name: string) -> string throws root.errors.InvalidArgument { | ||
| $rust_io_function | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2258,6 +2258,32 @@ impl<'db> TypeInferenceBuilder<'db> { | |
| ), | ||
| } | ||
| } | ||
| Definition::Let(let_loc) => { | ||
| // Determine type from the let-binding's origin. | ||
| let db = self.context.db(); | ||
| let item_tree = | ||
| baml_compiler2_hir::file_item_tree(db, let_loc.file(db)); | ||
| let let_data = &item_tree[let_loc.id(db)]; | ||
| match let_data.origin { | ||
| baml_compiler2_ast::ast::LetOrigin::Client => { | ||
| // client<llm> declarations produce Client instances. | ||
| Ty::Class(crate::ty::QualifiedTypeName::new( | ||
| baml_base::Name::new("baml"), | ||
| vec![baml_base::Name::new("llm")], | ||
| baml_base::Name::new("Client"), | ||
| )) | ||
| } | ||
| baml_compiler2_ast::ast::LetOrigin::RetryPolicy => { | ||
| // retry_policy declarations produce RetryPolicy instances. | ||
| Ty::Class(crate::ty::QualifiedTypeName::new( | ||
| baml_base::Name::new("baml"), | ||
| vec![baml_base::Name::new("llm")], | ||
| baml_base::Name::new("RetryPolicy"), | ||
| )) | ||
| } | ||
| _ => Ty::Unknown, | ||
| } | ||
|
Comment on lines
+2261
to
+2285
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Handle let-bound globals in package-qualified paths too. This branch only fixes bare identifiers through |
||
| } | ||
| _ => Ty::Unknown, | ||
| } | ||
| } else if let Some(def) = self.package_items.lookup_type(&lookup_path) { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`call_llm_function` now mutates round-robin state before any step runs, which breaks fallback behavior when an earlier branch succeeds. Because `build_plan()` flattens all fallback branches but the loop returns on the first successful step, many planned steps are never attempted; however, `client.advance_round_robin()` still increments counters in those untouched subtrees, so future calls can skip providers that were never actually used. This is a regression from the previous execute-on-visit behavior and changes routing deterministically for `Fallback[..., RoundRobin[...]]` clients. Useful? React with 👍 / 👎.