Skip to content

Commit 51fdb81

Browse files
committed
Allow passing full Client Registry
1 parent 0d680e9 commit 51fdb81

File tree

8 files changed

+423
-2
lines changed

8 files changed

+423
-2
lines changed

lib/baml_elixir/client.ex

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,18 @@ defmodule BamlElixir.Client do
452452
defp prepare_opts(opts) do
453453
path = opts[:path] || "baml_src"
454454
collectors = (opts[:collectors] || []) |> Enum.map(fn collector -> collector.reference end)
455-
client_registry = opts[:llm_client] && %{primary: opts[:llm_client]}
455+
456+
client_registry =
457+
if opts[:client_registry] do
458+
opts[:client_registry]
459+
else
460+
if opts[:llm_client] do
461+
%{primary: opts[:llm_client]}
462+
else
463+
nil
464+
end
465+
end
466+
456467
{path, collectors, client_registry, opts[:tb]}
457468
end
458469

mix.exs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ defmodule BamlElixir.MixProject do
1010
version: @version,
1111
elixir: "~> 1.17",
1212
start_permanent: Mix.env() == :prod,
13+
elixirc_paths: elixirc_paths(Mix.env()),
1314
deps: deps(),
1415
package: package()
1516
]
@@ -23,10 +24,14 @@ defmodule BamlElixir.MixProject do
2324
end
2425

2526
# Run "mix help deps" to learn about dependencies.
27+
defp elixirc_paths(:test), do: ["lib", "test/support"]
28+
defp elixirc_paths(_), do: ["lib"]
29+
2630
defp deps do
2731
[
2832
{:rustler, "~> 0.36.1", optional: true},
2933
{:rustler_precompiled, "~> 0.8"},
34+
{:mox, "~> 1.1", only: :test},
3035
{:ex_doc, ">= 0.0.0", only: :dev, runtime: false}
3136
]
3237
end

mix.lock

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
"makeup": {:hex, :makeup, "1.2.1", "e90ac1c65589ef354378def3ba19d401e739ee7ee06fb47f94c687016e3713d1", [:mix], [{:nimble_parsec, "~> 1.4", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "d36484867b0bae0fea568d10131197a4c2e47056a6fbe84922bf6ba71c8d17ce"},
77
"makeup_elixir": {:hex, :makeup_elixir, "1.0.1", "e928a4f984e795e41e3abd27bfc09f51db16ab8ba1aebdba2b3a575437efafc2", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "7284900d412a3e5cfd97fdaed4f5ed389b8f2b4cb49efc0eb3bd10e2febf9507"},
88
"makeup_erlang": {:hex, :makeup_erlang, "1.0.2", "03e1804074b3aa64d5fad7aa64601ed0fb395337b982d9bcf04029d68d51b6a7", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "af33ff7ef368d5893e4a267933e7744e46ce3cf1f61e2dccf53a111ed3aa3727"},
9+
"mox": {:hex, :mox, "1.2.0", "a2cd96b4b80a3883e3100a221e8adc1b98e4c3a332a8fc434c39526babafd5b3", [:mix], [{:nimble_ownership, "~> 1.0", [hex: :nimble_ownership, repo: "hexpm", optional: false]}], "hexpm", "c7b92b3cc69ee24a7eeeaf944cd7be22013c52fcb580c1f33f50845ec821089a"},
10+
"nimble_ownership": {:hex, :nimble_ownership, "1.0.2", "fa8a6f2d8c592ad4d79b2ca617473c6aefd5869abfa02563a77682038bf916cf", [:mix], [], "hexpm", "098af64e1f6f8609c6672127cfe9e9590a5d3fcdd82bc17a377b8692fd81a879"},
911
"nimble_parsec": {:hex, :nimble_parsec, "1.4.2", "8efba0122db06df95bfaa78f791344a89352ba04baedd3849593bfce4d0dc1c6", [:mix], [], "hexpm", "4b21398942dda052b403bbe1da991ccd03a053668d147d53fb8c4e0efe09c973"},
1012
"rustler": {:hex, :rustler, "0.36.1", "2d4b1ff57ea2789a44756a40dbb5fbb73c6ee0a13d031dcba96d0a5542598a6a", [:mix], [{:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:toml, "~> 0.7", [hex: :toml, repo: "hexpm", optional: false]}], "hexpm", "f3fba4ad272970e0d1bc62972fc4a99809651e54a125c5242de9bad4574b2d02"},
1113
"rustler_precompiled": {:hex, :rustler_precompiled, "0.8.2", "5f25cbe220a8fac3e7ad62e6f950fcdca5a5a5f8501835d2823e8c74bf4268d5", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "63d1bd5f8e23096d1ff851839923162096364bac8656a4a3c00d1fff8e83ee0a"},

native/baml_elixir/src/lib.rs

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use baml_runtime::client_registry::ClientRegistry;
1+
use baml_runtime::client_registry::{ClientProperty, ClientProvider, ClientRegistry};
22
use baml_runtime::tracingv2::storage::storage::Collector;
33
use baml_runtime::type_builder::TypeBuilder;
44
use baml_runtime::{BamlRuntime, FunctionResult, RuntimeContextManager, TripWire};
@@ -12,6 +12,7 @@ use rustler::{
1212
};
1313
use std::collections::HashMap;
1414
use std::path::Path;
15+
use std::str::FromStr;
1516
use std::sync::Arc;
1617
mod atoms {
1718
rustler::atoms! {
@@ -83,6 +84,73 @@ fn term_to_baml_value<'a>(term: Term<'a>) -> Result<BamlValue, Error> {
8384
))))
8485
}
8586

87+
fn term_to_optional_string(term: Term) -> Result<Option<String>, Error> {
88+
if term.is_atom() && term.decode::<rustler::Atom>()? == atom::nil() {
89+
Ok(None)
90+
} else {
91+
Ok(Some(term_to_string(term)?))
92+
}
93+
}
94+
95+
fn term_to_baml_map(term: Term) -> Result<BamlMap<String, BamlValue>, Error> {
96+
if term.is_atom() && term.decode::<rustler::Atom>()? == atom::nil() {
97+
return Ok(BamlMap::new());
98+
}
99+
if !term.is_map() {
100+
return Err(Error::Term(Box::new("Expected a map")));
101+
}
102+
let mut map = BamlMap::new();
103+
for (key_term, value_term) in
104+
MapIterator::new(term).ok_or(Error::Term(Box::new("Invalid map")))?
105+
{
106+
let key = term_to_string(key_term)?;
107+
let value = term_to_baml_value(value_term)?;
108+
map.insert(key, value);
109+
}
110+
Ok(map)
111+
}
112+
113+
fn term_to_client_property(term: Term, name_override: Option<String>) -> Result<ClientProperty, Error> {
114+
if !term.is_map() {
115+
return Err(Error::Term(Box::new("Client must be a map")));
116+
}
117+
118+
let mut name: Option<String> = name_override;
119+
let mut provider: Option<ClientProvider> = None;
120+
let mut retry_policy: Option<String> = None;
121+
let mut options: BamlMap<String, BamlValue> = BamlMap::new();
122+
123+
let iter = MapIterator::new(term).ok_or(Error::Term(Box::new("Invalid client map")))?;
124+
for (key_term, value_term) in iter {
125+
let key = term_to_string(key_term)?;
126+
match key.as_str() {
127+
"name" => {
128+
name = Some(term_to_string(value_term)?);
129+
}
130+
"provider" => {
131+
let provider_str = term_to_string(value_term)?;
132+
provider = Some(
133+
ClientProvider::from_str(&provider_str).map_err(|e| {
134+
Error::Term(Box::new(format!("Invalid client provider: {e}")))
135+
})?,
136+
);
137+
}
138+
"retry_policy" => {
139+
retry_policy = term_to_optional_string(value_term)?;
140+
}
141+
"options" => {
142+
options = term_to_baml_map(value_term)?;
143+
}
144+
_ => {}
145+
}
146+
}
147+
148+
let name = name.ok_or(Error::Term(Box::new("Client missing required key: name")))?;
149+
let provider = provider.ok_or(Error::Term(Box::new("Client missing required key: provider")))?;
150+
151+
Ok(ClientProperty::new(name, provider, retry_policy, options))
152+
}
153+
86154
fn baml_value_to_term<'a>(env: Env<'a>, value: &BamlValue) -> NifResult<Term<'a>> {
87155
match value {
88156
BamlValue::String(s) => Ok(s.encode(env)),
@@ -218,6 +286,32 @@ fn prepare_request<'a>(
218286
if key == "primary" {
219287
let primary = term_to_string(value_term)?;
220288
registry.set_primary(primary);
289+
} else if key == "clients" {
290+
// Accept either:
291+
// - a list of client maps: [%{name: ..., provider: ..., ...}, ...]
292+
// - a map of name => client map: %{ "name" => %{provider: ..., ...}, ... }
293+
if let Ok(list) = value_term.decode::<Vec<Term>>() {
294+
for client_term in list {
295+
let client = term_to_client_property(client_term, None)?;
296+
registry.add_client(client);
297+
}
298+
} else if value_term.is_map() {
299+
let client_iter = MapIterator::new(value_term)
300+
.ok_or(Error::Term(Box::new("Invalid clients map")))?;
301+
for (name_term, client_term) in client_iter {
302+
let name = term_to_string(name_term)?;
303+
let client = term_to_client_property(client_term, Some(name))?;
304+
registry.add_client(client);
305+
}
306+
} else if value_term.is_atom()
307+
&& value_term.decode::<rustler::Atom>()? == atom::nil()
308+
{
309+
// allow nil clients
310+
} else {
311+
return Err(Error::Term(Box::new(
312+
"Client registry clients must be a list, a map, or nil",
313+
)));
314+
}
221315
}
222316
}
223317
Some(registry)

test/baml_elixir_test.exs

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,113 @@ defmodule BamlElixirTest do
22
use ExUnit.Case
33
use BamlElixir.Client, path: "test/baml_src"
44

5+
import Mox
6+
57
alias BamlElixir.TypeBuilder
68

79
doctest BamlElixir
810

11+
setup :set_mox_from_context
12+
setup :verify_on_exit!
13+
14+
@tag :client_registry
15+
test "client_registry supports clients key (list form)" do
16+
client_registry = %{
17+
primary: "InjectedClient",
18+
clients: [
19+
%{
20+
name: "InjectedClient",
21+
provider: "definitely-not-a-provider",
22+
retry_policy: nil,
23+
options: %{model: "gpt-4o-mini"}
24+
}
25+
]
26+
}
27+
28+
# parse: false to avoid any parsing work; we want to exercise registry decoding/validation
29+
assert {:error, msg} =
30+
BamlElixirTest.WhichModel.call(%{}, %{client_registry: client_registry, parse: false})
31+
32+
assert msg =~ "Invalid client provider"
33+
end
34+
35+
@tag :client_registry
36+
test "client_registry supports clients key (map form)" do
37+
client_registry = %{
38+
primary: "InjectedClient",
39+
clients: %{
40+
"InjectedClient" => %{
41+
provider: "definitely-not-a-provider",
42+
retry_policy: nil,
43+
options: %{model: "gpt-4o-mini"}
44+
}
45+
}
46+
}
47+
48+
assert {:error, msg} =
49+
BamlElixirTest.WhichModel.call(%{}, %{client_registry: client_registry, parse: false})
50+
51+
assert msg =~ "Invalid client provider"
52+
end
53+
54+
@tag :client_registry
55+
test "client_registry can inject and select a client not present in the BAML files (success path)" do
56+
BamlElixirTest.FakeOpenAIServer.expect_chat_completion("GPT")
57+
base_url = BamlElixirTest.FakeOpenAIServer.start_base_url()
58+
59+
client_registry = %{
60+
primary: "InjectedClient",
61+
clients: [
62+
%{
63+
name: "InjectedClient",
64+
provider: "openai-generic",
65+
retry_policy: nil,
66+
options: %{
67+
base_url: base_url,
68+
api_key: "test-key",
69+
model: "gpt-4o-mini"
70+
}
71+
}
72+
]
73+
}
74+
75+
# This function declares `client GPT4` in the .baml file, so success here proves
76+
# `client_registry.primary` overrides the static client selection.
77+
assert {:ok, "GPT"} =
78+
BamlElixirTest.WhichModelUnion.call(%{}, %{client_registry: client_registry})
79+
end
80+
81+
@tag :client_registry
82+
test "client_registry passes clients[].options.headers into the HTTP request" do
83+
BamlElixirTest.FakeOpenAIServer.expect_chat_completion("GPT", %{
84+
"x-test-header" => "hello-from-elixir"
85+
})
86+
87+
base_url = BamlElixirTest.FakeOpenAIServer.start_base_url()
88+
89+
client_registry = %{
90+
primary: "InjectedClient",
91+
clients: [
92+
%{
93+
name: "InjectedClient",
94+
provider: "openai-generic",
95+
retry_policy: nil,
96+
options: %{
97+
base_url: base_url,
98+
api_key: "test-key",
99+
model: "gpt-4o-mini",
100+
headers: %{
101+
"x-test-header" => "hello-from-elixir"
102+
}
103+
}
104+
}
105+
]
106+
}
107+
108+
assert {:ok, "GPT"} =
109+
BamlElixirTest.WhichModelUnion.call(%{}, %{client_registry: client_registry})
110+
end
111+
9112
test "parses into a struct" do
10113
assert {:ok, %BamlElixirTest.Person{name: "John Doe", age: 28}} =
11114
BamlElixirTest.ExtractPerson.call(%{info: "John Doe, 28, Engineer"})

0 commit comments

Comments
 (0)