Skip to content

Commit 343318a

Browse files
committed
Generate structs, types and calls
1 parent 1616338 commit 343318a

File tree

3 files changed

+286
-46
lines changed

3 files changed

+286
-46
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,4 @@ target/
3131
/priv/native/
3232
baml_client/
3333
baml_client.tmp
34+
/lib/baml_elixir/test.ex

lib/baml_elixir/client.ex

Lines changed: 187 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,19 @@ defmodule BamlElixir.Client do
2727
defmacro __using__(opts) do
2828
path = Keyword.get(opts, :path, "baml_src")
2929

30+
# Get BAML types
31+
baml_types = BamlElixir.Native.parse_baml(path)
32+
baml_class_types = baml_types[:classes]
33+
baml_enum_types = baml_types[:enums]
34+
baml_functions = baml_types[:functions]
35+
3036
quote do
3137
import BamlElixir.Client
3238

33-
# Get BAML types
34-
baml_types = BamlElixir.Native.parse_baml(unquote(path))
35-
baml_class_types = baml_types[:classes]
36-
baml_enum_types = baml_types[:enums]
37-
3839
# Generate types
39-
BamlElixir.Client.generate_class_types(__MODULE__, baml_class_types)
40-
BamlElixir.Client.generate_enum_types(__MODULE__, baml_enum_types)
40+
generate_class_types(unquote(baml_class_types))
41+
generate_enum_types(unquote(baml_enum_types))
42+
generate_call_function_clauses(unquote(baml_functions), unquote(path))
4143
end
4244
end
4345

@@ -64,7 +66,14 @@ defmodule BamlElixir.Client do
6466
def call(function_name, args, opts \\ %{}) do
6567
{path, collectors, client_registry} = prepare_opts(opts)
6668

67-
BamlElixir.Native.call(function_name, args, path, collectors, client_registry)
69+
with {:ok, result} <-
70+
BamlElixir.Native.call(function_name, args, path, collectors, client_registry) do
71+
if opts[:parse] != false do
72+
parse_result(result, opts[:prefix])
73+
else
74+
result
75+
end
76+
end
6877
end
6978

7079
@doc """
@@ -85,8 +94,8 @@ defmodule BamlElixir.Client do
8594
stream = BamlElixir.Client.stream!("MyFunction", %{arg1: "value"})
8695
Enum.each(stream, fn result -> IO.inspect(result) end)
8796
"""
88-
@spec stream!(String.t(), map(), map()) :: Enumerable.t()
89-
def stream!(function_name, args, opts \\ %{}) do
97+
@spec create_stream!(String.t(), map(), map()) :: Enumerable.t()
98+
def create_stream!(function_name, args, opts \\ %{}) do
9099
{path, collectors, client_registry} = prepare_opts(opts)
91100

92101
Stream.resource(
@@ -118,28 +127,45 @@ defmodule BamlElixir.Client do
118127
)
119128
end
120129

130+
def stream!(function_name, args, callback, opts \\ %{}) do
131+
create_stream!(function_name, args, opts)
132+
|> Stream.map(fn result ->
133+
result =
134+
if opts[:parse] != false do
135+
parse_result(result, opts[:prefix])
136+
else
137+
result
138+
end
139+
140+
callback.(result)
141+
end)
142+
|> Stream.run()
143+
144+
callback.(:done)
145+
end
146+
121147
@doc false
122-
def generate_class_types(module, class_types) do
148+
defmacro generate_class_types(class_types) do
149+
module = __CALLER__.module
150+
123151
for {type_name, fields} <- class_types do
124152
field_names = get_field_names(fields)
125-
field_types = get_field_types(fields)
153+
field_types = get_field_types(fields, __CALLER__)
126154
module_name = Module.concat([module, type_name])
127155

128-
Module.create(
129-
module_name,
130-
quote do
156+
quote do
157+
defmodule unquote(module_name) do
131158
defstruct unquote(field_names)
132159
@type t :: %__MODULE__{unquote_splicing(field_types)}
133-
end,
134-
Macro.Env.location(__ENV__)
135-
)
136-
137-
IO.puts("Generated BAML class module: #{inspect(module_name)}")
160+
end
161+
end
138162
end
139163
end
140164

141165
@doc false
142-
def generate_enum_types(module, enum_types) do
166+
defmacro generate_enum_types(enum_types) do
167+
module = __CALLER__.module
168+
143169
for {enum_name, variants} <- enum_types do
144170
variant_atoms = Enum.map(variants, &String.to_atom/1)
145171
module_name = Module.concat([module, enum_name])
@@ -149,15 +175,132 @@ defmodule BamlElixir.Client do
149175
{:|, [], [atom, acc]}
150176
end)
151177

152-
Module.create(
153-
module_name,
154-
quote do
178+
quote do
179+
defmodule unquote(module_name) do
155180
@type t :: unquote(union_type)
156-
end,
157-
Macro.Env.location(__ENV__)
158-
)
181+
end
182+
end
183+
end
184+
end
185+
186+
@doc false
187+
defmacro generate_call_function_clauses(functions, path) do
188+
for {function_name, function_info} <- functions do
189+
function_atom = String.to_atom("call#{function_name}")
190+
191+
param_types =
192+
for {param_name, param_type} <- function_info["params"] do
193+
{String.to_atom(param_name), to_elixir_type(param_type, __CALLER__)}
194+
end
195+
196+
typespec =
197+
quote do
198+
@spec unquote(function_atom)(%{unquote_splicing(param_types)}, map()) ::
199+
{:ok, unquote(to_elixir_type(function_info["return_type"], __CALLER__))}
200+
| {:error, String.t()}
201+
end
202+
203+
function_clause =
204+
quote do
205+
def unquote(function_atom)(args, opts \\ %{}) do
206+
opts = Map.put(opts, :path, unquote(path))
207+
call(unquote(function_name), args, opts)
208+
end
209+
end
210+
211+
quote do
212+
unquote(typespec)
213+
unquote(function_clause)
214+
end
215+
end
216+
end
217+
218+
defp to_elixir_type(type, caller) do
219+
case type do
220+
{:primitive, primitive} ->
221+
case primitive do
222+
:string ->
223+
quote(do: String.t())
224+
225+
:integer ->
226+
quote(do: integer())
227+
228+
:float ->
229+
quote(do: float())
230+
231+
:boolean ->
232+
quote(do: boolean())
233+
234+
nil ->
235+
quote(do: nil)
236+
237+
:media ->
238+
quote(
239+
do:
240+
%{url: String.t()}
241+
| %{url: String.t(), media_type: String.t()}
242+
| %{base64: String.t()}
243+
| %{base64: String.t(), media_type: String.t()}
244+
)
245+
end
246+
247+
{:enum, name} ->
248+
# Convert enum name to module reference with .t()
249+
module = Module.concat([caller.module, name])
250+
quote(do: unquote(module).t())
251+
252+
{:class, name} ->
253+
# Convert class name to module reference with .t()
254+
module = Module.concat([caller.module, name])
255+
quote(do: unquote(module).t())
256+
257+
{:list, inner_type} ->
258+
# Convert to list type
259+
quote(do: [unquote(to_elixir_type(inner_type, caller))])
260+
261+
{:map, key_type, value_type} ->
262+
# Convert to map type
263+
quote(
264+
do: %{
265+
unquote(to_elixir_type(key_type, caller)) =>
266+
unquote(to_elixir_type(value_type, caller))
267+
}
268+
)
269+
270+
{:literal, value} ->
271+
# For literals, use the value directly
272+
case value do
273+
v when is_atom(v) -> v
274+
v when is_integer(v) -> v
275+
v when is_boolean(v) -> v
276+
end
159277

160-
IO.puts("Generated BAML enum module: #{inspect(module_name)}")
278+
{:union, types} ->
279+
# Convert union to pipe operator
280+
[first_type | rest_types] = types
281+
first_ast = to_elixir_type(first_type, caller)
282+
283+
Enum.reduce(rest_types, first_ast, fn type, acc ->
284+
{:|, [], [to_elixir_type(type, caller), acc]}
285+
end)
286+
287+
{:tuple, types} ->
288+
# Convert to tuple type
289+
types_ast = Enum.map(types, &to_elixir_type(&1, caller))
290+
{:{}, [], types_ast}
291+
292+
{:optional, inner_type} ->
293+
# Convert optional to union with nil
294+
{:|, [], [to_elixir_type(inner_type, caller), nil]}
295+
296+
{:alias, name} ->
297+
# For recursive type aliases, use the name with .t()
298+
module = String.to_atom(name)
299+
quote(do: unquote(module).t())
300+
301+
_ ->
302+
# Fallback to any
303+
quote(do: any())
161304
end
162305
end
163306

@@ -167,18 +310,9 @@ defmodule BamlElixir.Client do
167310
end
168311
end
169312

170-
defp get_field_types(fields) do
313+
defp get_field_types(fields, caller) do
171314
for {field_name, field_type} <- fields do
172-
elixir_type =
173-
case field_type do
174-
"string" -> :string
175-
"int" -> :integer
176-
"float" -> :float
177-
"bool" -> :boolean
178-
# For custom types like Company
179-
_ -> :any
180-
end
181-
315+
elixir_type = to_elixir_type(field_type, caller)
182316
{String.to_atom(field_name), elixir_type}
183317
end
184318
end
@@ -189,4 +323,18 @@ defmodule BamlElixir.Client do
189323
client_registry = opts[:llm_client] && %{primary: opts[:llm_client]}
190324
{path, collectors, client_registry}
191325
end
326+
327+
defp parse_result(%{:__baml_class__ => class_name} = result, prefix) do
328+
module = Module.concat(prefix, class_name)
329+
values = Enum.map(result, fn {key, value} -> {key, parse_result(value, prefix)} end)
330+
struct(module, values)
331+
end
332+
333+
defp parse_result(%{:__baml_enum__ => _, :value => value}, _prefix) do
334+
String.to_atom(value)
335+
end
336+
337+
defp parse_result(result, _prefix) do
338+
result
339+
end
192340
end

0 commit comments

Comments
 (0)