Skip to content

Commit 11ce9f1

Browse files
committed
baml-lanugage/fix: use build_plan for llm functions, delete execute
1 parent c574bff commit 11ce9f1

File tree

4 files changed

+109
-237
lines changed

4 files changed

+109
-237
lines changed

baml_language/crates/baml_builtins2/baml_std/baml/llm.baml

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,6 @@ class OrchestrationStep {
33
delay_ms int
44
}
55

6-
class ExecutionResult {
7-
ok bool
8-
value unknown
9-
}
10-
116
class ExecutionContext {
127
jinja_string string
138
args map<string, unknown>
@@ -58,6 +53,19 @@ function call_llm_function<T>(client: Client, function_name: string, args: map<s
5853
function_name: function_name,
5954
};
6055

61-
let result: T = client.execute(context, 0);
62-
result
56+
let steps = client.build_plan();
57+
client.advance_round_robin();
58+
59+
for (let step in steps) {
60+
if (step.delay_ms > 0) {
61+
root.sys.sleep(step.delay_ms);
62+
}
63+
64+
let result: T = execute_step(step, context) catch (e) {
65+
_ => { continue; }
66+
};
67+
return result;
68+
}
69+
70+
throw root.errors.DevOther { message: "All orchestration steps failed" };
6371
}

baml_language/crates/baml_builtins2/baml_std/baml/llm_types.baml

Lines changed: 32 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,20 @@ class Client {
3232
}
3333
}
3434

35+
function advance_round_robin(self) -> void {
36+
match (self.client_type) {
37+
ClientType.Primitive => {},
38+
ClientType.Fallback => {
39+
for (let sub in self.sub_clients) {
40+
sub.advance_round_robin();
41+
}
42+
},
43+
ClientType.RoundRobin => {
44+
self.counter += 1;
45+
},
46+
}
47+
}
48+
3549
function get_constructor(
3650
self
3751
) -> () -> PrimitiveClient throws root.errors.InvalidArgument {
@@ -120,109 +134,6 @@ class Client {
120134
}
121135
}
122136

123-
function execute<T>(
124-
self,
125-
context: ExecutionContext,
126-
inherited_delay_ms: int,
127-
) -> T {
128-
match (self.retry) {
129-
r: RetryPolicy => {
130-
let current_delay = r.initial_delay_ms + 0.0
131-
132-
for (let attempt = 0; attempt <= r.max_retries; attempt += 1) {
133-
let attempt_delay = inherited_delay_ms
134-
if (attempt > 0) {
135-
attempt_delay = root.math.trunc(current_delay)
136-
let next = current_delay * r.multiplier
137-
if (next > r.max_delay_ms + 0.0) {
138-
current_delay = r.max_delay_ms + 0.0
139-
} else {
140-
current_delay = next
141-
}
142-
}
143-
144-
if (attempt == r.max_retries) {
145-
attempt_delay = inherited_delay_ms
146-
}
147-
148-
let result2: T = self.execute_once(
149-
context,
150-
attempt_delay,
151-
) catch (e) {
152-
_ => { continue; }
153-
};
154-
155-
return result2;
156-
}
157-
158-
throw root.errors.DevOther { message: "All orchestration steps failed" };
159-
}
160-
null => {
161-
let result: T = self.execute_once(
162-
context,
163-
inherited_delay_ms,
164-
);
165-
return result;
166-
},
167-
}
168-
}
169-
170-
function execute_once<T>(
171-
self,
172-
context: ExecutionContext,
173-
active_delay_ms: int,
174-
) -> T {
175-
match (self.client_type) {
176-
ClientType.Primitive => {
177-
let resolve_fn = self.get_constructor()
178-
let primitive = resolve_fn()
179-
180-
let prompt = primitive.render_prompt(context.jinja_string, context.args)
181-
let specialized = primitive.specialize_prompt(prompt)
182-
let http_request = primitive.build_request(specialized)
183-
let http_response = root.http.send(http_request)
184-
185-
if (http_response.ok()) {
186-
let body = http_response.text()
187-
let return_type = get_return_type(context.function_name)
188-
let result: T = primitive.parse(body, return_type) catch (e) {
189-
_ => {
190-
if (active_delay_ms > 0) {
191-
root.sys.sleep(active_delay_ms)
192-
}
193-
throw e;
194-
}
195-
};
196-
result
197-
} else {
198-
throw root.errors.DevOther { message: "HTTP request failed" };
199-
}
200-
}
201-
202-
ClientType.Fallback => {
203-
for (let sub in self.sub_clients) {
204-
let result2: T = sub.execute(
205-
context,
206-
active_delay_ms,
207-
) catch (e) {
208-
_ => { continue; }
209-
};
210-
return result2;
211-
}
212-
throw root.errors.DevOther { message: "All orchestration steps failed" };
213-
}
214-
215-
ClientType.RoundRobin => {
216-
let idx = self.counter % self.sub_clients.length()
217-
self.counter += 1
218-
let result3: T = self.sub_clients[idx].execute(
219-
context,
220-
active_delay_ms,
221-
);
222-
return result3;
223-
}
224-
}
225-
}
226137
}
227138

228139
class PrimitiveClientOptions {
@@ -285,6 +196,24 @@ class PrimitiveClient {
285196
}
286197
}
287198

199+
function execute_step<T>(
200+
step: OrchestrationStep,
201+
context: ExecutionContext,
202+
) -> T {
203+
let prompt = step.primitive_client.render_prompt(context.jinja_string, context.args);
204+
let specialized = step.primitive_client.specialize_prompt(prompt);
205+
let http_request = step.primitive_client.build_request(specialized);
206+
let http_response = root.http.send(http_request);
207+
208+
if (http_response.ok()) {
209+
let body = http_response.text();
210+
let return_type = get_return_type(context.function_name);
211+
step.primitive_client.parse(body, return_type)
212+
} else {
213+
throw root.errors.DevOther { message: "HTTP request failed" };
214+
}
215+
}
216+
288217
function get_jinja_template(function_name: string) -> string throws root.errors.InvalidArgument {
289218
$rust_io_function
290219
}

baml_language/crates/baml_compiler2_tir/src/builder.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2235,6 +2235,32 @@ impl<'db> TypeInferenceBuilder<'db> {
22352235
// reported here — they'll be reported at the definition site.
22362236
ty
22372237
}
2238+
Definition::Let(let_loc) => {
2239+
// Determine type from the let-binding's origin.
2240+
let db = self.context.db();
2241+
let item_tree =
2242+
baml_compiler2_hir::file_item_tree(db, let_loc.file(db));
2243+
let let_data = &item_tree[let_loc.id(db)];
2244+
match let_data.origin {
2245+
baml_compiler2_ast::ast::LetOrigin::Client => {
2246+
// client<llm> declarations produce Client instances.
2247+
Ty::Class(crate::ty::QualifiedTypeName::new(
2248+
baml_base::Name::new("baml"),
2249+
vec![baml_base::Name::new("llm")],
2250+
baml_base::Name::new("Client"),
2251+
))
2252+
}
2253+
baml_compiler2_ast::ast::LetOrigin::RetryPolicy => {
2254+
// retry_policy declarations produce RetryPolicy instances.
2255+
Ty::Class(crate::ty::QualifiedTypeName::new(
2256+
baml_base::Name::new("baml"),
2257+
vec![baml_base::Name::new("llm")],
2258+
baml_base::Name::new("RetryPolicy"),
2259+
))
2260+
}
2261+
_ => Ty::Unknown,
2262+
}
2263+
}
22382264
_ => Ty::Unknown,
22392265
}
22402266
} else if let Some(def) = self.package_items.lookup_type(&lookup_path) {

0 commit comments

Comments
 (0)