Skip to content

Commit b93e87f

Browse files
authored
Refactor containers (#590)
* Refactor containers * Fix warning
1 parent 9fce600 commit b93e87f

File tree

4 files changed

+56
-226
lines changed

4 files changed

+56
-226
lines changed

lib/axon.ex

Lines changed: 49 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,8 @@ defmodule Axon do
279279
alias __MODULE__, as: Axon
280280
alias Axon.Parameter
281281

282+
import Axon.Shared
283+
282284
require Logger
283285

284286
@type t :: %__MODULE__{}
@@ -380,15 +382,6 @@ defmodule Axon do
380382
}
381383
end
382384

383-
defp split_inputs(:container, [inputs]) do
384-
{inputs, cache} =
385-
deep_map_reduce(inputs, %{}, fn %Axon{output: id, nodes: nodes}, cache ->
386-
{id, Map.merge(nodes, cache)}
387-
end)
388-
389-
{[inputs], [], [:layer], cache}
390-
end
391-
392385
defp split_inputs(_op, inputs) do
393386
Enum.reduce(inputs, {[], [], [], %{}}, fn
394387
%Axon{output: layer_input, nodes: nodes}, {layers, params, args, cache} ->
@@ -704,62 +697,47 @@ defmodule Axon do
704697
@doc type: :special
705698
def container(container, opts \\ []) do
706699
opts = Keyword.validate!(opts, [:name, :meta])
707-
708-
layer(:container, [container], name: opts[:name], meta: opts[:meta], op_name: :container)
700+
{structure_fn, nodes} = destructure(container)
701+
layer(structure_fn, nodes, name: opts[:name], meta: opts[:meta], op_name: :container)
709702
end
710703

711-
# TODO: This should not be duplicated
712-
defp deep_new(%Nx.Tensor{} = x, fun), do: fun.(x)
713-
714-
defp deep_new(x, fun) when is_number(x), do: fun.(x)
715-
716-
defp deep_new(map, fun) do
717-
{cont, :ok} = Nx.Container.traverse(map, :ok, &recur_traverse(&1, &2, fun))
718-
cont
704+
defp destructure(container) do
705+
{structure, {nodes, _}} = recur_destructure(container, {[], 0})
706+
fun = restructure(length(nodes) + 1, structure)
707+
{fun, Enum.reverse(nodes)}
719708
end
720709

721-
defp recur_traverse(item, :ok, fun) do
722-
case item do
723-
%Axon{} = t ->
724-
{fun.(t), :ok}
725-
726-
%{axon: :axon} = t ->
727-
{fun.(t), :ok}
710+
defp recur_destructure(container, acc) do
711+
Nx.Container.traverse(container, acc, fn value, {leaves, idx} ->
712+
case value do
713+
%Axon{} = leaf ->
714+
{idx, {[leaf | leaves], idx + 1}}
728715

729-
container ->
730-
{deep_new(container, fun), :ok}
731-
end
716+
container ->
717+
recur_destructure(container, {leaves, idx})
718+
end
719+
end)
732720
end
733721

734-
defp deep_merge(left, right, fun) do
735-
case Nx.Container.traverse(left, leaves(right), &recur_merge(&1, &2, fun)) do
736-
{merged, []} ->
737-
merged
722+
for i <- 0..128 do
723+
args = Macro.generate_arguments(i, __MODULE__)
738724

739-
{_merged, _leftover} ->
740-
raise ArgumentError,
741-
"unable to merge arguments with incompatible" <>
742-
" structure"
725+
defp restructure(unquote(i), structure) do
726+
fn unquote_splicing(args) ->
727+
args_tuple = {unquote_splicing(args)}
728+
{container, :ok} = recur_restructure(structure, args_tuple)
729+
container
730+
end
743731
end
744732
end
745733

746-
defp leaves(container) do
747-
container
748-
|> Nx.Container.reduce([], fn x, acc -> [x | acc] end)
749-
|> Enum.reverse()
750-
end
751-
752-
defp recur_merge(left, [right | right_leaves], fun) do
753-
case {left, right} do
754-
{%Nx.Tensor{} = left, %Nx.Tensor{} = right} ->
755-
{fun.(left, right), right_leaves}
756-
757-
{%Axon{} = left, %Axon{} = right} ->
758-
{fun.(left, right), right_leaves}
759-
760-
{left, right} ->
761-
{deep_merge(left, right, fun), right_leaves}
762-
end
734+
defp recur_restructure(structure, args_tuple) do
735+
Nx.Container.traverse(structure, :ok, fn value, :ok ->
736+
case value do
737+
idx when is_integer(idx) -> {elem(args_tuple, idx), :ok}
738+
container -> recur_restructure(container, args_tuple)
739+
end
740+
end)
763741
end
764742

765743
@doc """
@@ -3644,35 +3622,31 @@ defmodule Axon do
36443622
end
36453623

36463624
@doc """
3647-
Returns a model's output shape from the given input
3625+
Returns a model's output template from the given input
36483626
template.
3627+
3628+
The output template gives you access to the output shape
3629+
and type of the given input graph.
36493630
"""
36503631
@doc type: :graph
36513632
def get_output_shape(%Axon{} = axon, inputs, opts \\ []) do
36523633
{init_fn, forward_fn} = build(axon, opts ++ [raise_on_none: false])
36533634

3654-
out =
3635+
inputs =
3636+
case inputs do
3637+
%Nx.Tensor{} = input -> Nx.to_template(input)
3638+
inputs when is_map(inputs) -> Map.new(inputs, fn {k, v} -> {k, Nx.to_template(v)} end)
3639+
end
3640+
3641+
fun =
36553642
Nx.Defn.jit(
36563643
fn inputs ->
36573644
forward_fn.(init_fn.(inputs, Axon.ModelState.empty()), inputs)
36583645
end,
36593646
compiler: Axon.Defn
3660-
).(inputs)
3661-
3662-
safe_shape(out)
3663-
end
3664-
3665-
defp safe_shape(container_or_tensor) do
3666-
case container_or_tensor do
3667-
%Axon.None{} = none ->
3668-
none
3669-
3670-
%Nx.Tensor{} = tensor ->
3671-
Nx.shape(tensor)
3647+
)
36723648

3673-
container ->
3674-
deep_new(container, &Nx.shape/1)
3675-
end
3649+
deep_new(apply(fun, [inputs]), &Nx.to_template/1)
36763650
end
36773651

36783652
@doc """
@@ -3850,74 +3824,17 @@ defmodule Axon do
38503824
if MapSet.member?(visited, id) do
38513825
{acc, visited}
38523826
else
3853-
%{op: op, parent: parents} = parent = nodes[id]
3827+
%{parent: parents} = parent = nodes[id]
38543828

38553829
{acc, visited} =
3856-
case op do
3857-
:container ->
3858-
[container] = parents
3859-
3860-
deep_reduce(container, {acc, visited}, fn pid, {acc, visited} ->
3861-
traverse_nodes(pid, nodes, acc, visited)
3862-
end)
3863-
3864-
_ ->
3865-
Enum.reduce(parents, {acc, visited}, fn pid, {acc, visited} ->
3866-
traverse_nodes(pid, nodes, acc, visited)
3867-
end)
3868-
end
3830+
Enum.reduce(parents, {acc, visited}, fn pid, {acc, visited} ->
3831+
traverse_nodes(pid, nodes, acc, visited)
3832+
end)
38693833

38703834
{[parent | acc], MapSet.put(visited, id)}
38713835
end
38723836
end
38733837

3874-
# TODO: Do not duplicate
3875-
defp deep_reduce(item, acc, fun) when is_integer(item) do
3876-
fun.(item, acc)
3877-
end
3878-
3879-
defp deep_reduce(map, acc, fun) do
3880-
Nx.Container.reduce(map, acc, &recur_deep_reduce(&1, &2, fun))
3881-
end
3882-
3883-
defp recur_deep_reduce(value, acc, fun) do
3884-
case value do
3885-
%Axon{} = val ->
3886-
fun.(val, acc)
3887-
3888-
%Nx.Tensor{} = val ->
3889-
fun.(val, acc)
3890-
3891-
%{axon: :axon} = val ->
3892-
fun.(val, acc)
3893-
3894-
val when is_integer(val) ->
3895-
fun.(val, acc)
3896-
3897-
val ->
3898-
deep_reduce(val, acc, fun)
3899-
end
3900-
end
3901-
3902-
defp deep_map_reduce(leaf, acc, fun) when is_integer(leaf), do: fun.(leaf, acc)
3903-
3904-
defp deep_map_reduce(container, acc, fun) do
3905-
Nx.Container.traverse(container, acc, &recur_deep_map_reduce(&1, &2, fun))
3906-
end
3907-
3908-
defp recur_deep_map_reduce(leaf, acc, fun) do
3909-
case leaf do
3910-
%Axon{} = leaf ->
3911-
fun.(leaf, acc)
3912-
3913-
%Nx.Tensor{} = leaf ->
3914-
fun.(leaf, acc)
3915-
3916-
container ->
3917-
deep_map_reduce(container, acc, fun)
3918-
end
3919-
end
3920-
39213838
@doc """
39223839
Pops the top node off of the graph.
39233840

lib/axon/compiler.ex

Lines changed: 0 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ defmodule Axon.Compiler do
4040
@moduledoc false
4141
require Logger
4242

43-
import Axon.Shared
4443
alias Axon.StatefulOutput
4544

4645
## Init JIT Compilation
@@ -549,72 +548,6 @@ defmodule Axon.Compiler do
549548
{id, model_funs, cache, op_counts, block_cache, model_state_meta}
550549
end
551550

552-
defp recur_model_funs(
553-
%Axon.Node{id: id, op: :container, parent: [parents]},
554-
nodes,
555-
cache_and_counts,
556-
config
557-
) do
558-
{parent_ids, {cache, op_counts, block_cache, model_state_meta}} =
559-
deep_map_reduce(parents, cache_and_counts, &to_model_funs(&1, nodes, &2, config))
560-
561-
op_counts = Map.update(op_counts, :container, 1, fn x -> x + 1 end)
562-
563-
predict_fun = fn params, inputs, state, cache, result_cache, fn_stacktrace ->
564-
{input, {state, result_cache, none?}} =
565-
deep_map_reduce(
566-
parent_ids,
567-
{state, result_cache, false},
568-
fn parent_id, {state, result_cache, none?} ->
569-
{input, {state, result_cache}} =
570-
call_predict_cache(
571-
parent_id,
572-
params,
573-
inputs,
574-
state,
575-
cache,
576-
result_cache,
577-
fn_stacktrace
578-
)
579-
580-
none? = none? or propagating_none?(input)
581-
{input, {state, result_cache, none?}}
582-
end
583-
)
584-
585-
input = if none?, do: %Axon.None{}, else: input
586-
587-
{input, {state, result_cache}}
588-
end
589-
590-
init_fun = fn template, cache, result_cache, fn_stacktrace, keys ->
591-
{parent_template, {parent_params, result_cache, none?}} =
592-
deep_map_reduce(parent_ids, {%{}, result_cache, false}, fn
593-
parent_id, {params, result_cache, none?} ->
594-
{parent_template, {params, result_cache}} =
595-
call_init_cache(
596-
parent_id,
597-
template,
598-
params,
599-
cache,
600-
result_cache,
601-
fn_stacktrace,
602-
keys
603-
)
604-
605-
none? = none? or propagating_none?(parent_template)
606-
{parent_template, {params, result_cache, none?}}
607-
end)
608-
609-
parent_template = if none?, do: %Axon.None{}, else: parent_template
610-
611-
{parent_template, {parent_params, result_cache}}
612-
end
613-
614-
model_funs = %{predict: predict_fun, init: init_fun}
615-
{id, model_funs, cache, op_counts, block_cache, model_state_meta}
616-
end
617-
618551
defp recur_model_funs(
619552
%Axon.Node{
620553
id: id,

lib/axon/display.ex

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ defmodule Axon.Display do
33
Module for rendering various visual representations of Axon models.
44
"""
55

6-
import Axon.Shared
76
alias Axon.Parameter
87

98
@compile {:no_warn_undefined, TableRex.Table}
@@ -94,7 +93,8 @@ defmodule Axon.Display do
9493
defp do_axon_to_rows(
9594
%Axon.Node{
9695
id: id,
97-
op: :container,
96+
op: structure,
97+
op_name: :container,
9898
parent: [parents],
9999
name: name_fn
100100
},
@@ -105,7 +105,7 @@ defmodule Axon.Display do
105105
model_info
106106
) do
107107
{input_names, {cache, op_counts, model_info}} =
108-
deep_map_reduce(parents, {cache, op_counts, model_info}, fn
108+
Enum.map_reduce(parents, {cache, op_counts, model_info}, fn
109109
parent_id, {cache, op_counts, model_info} ->
110110
{_, name, _shape, cache, op_counts, model_info} =
111111
axon_to_rows(parent_id, nodes, templates, cache, op_counts, model_info)
@@ -119,7 +119,7 @@ defmodule Axon.Display do
119119
shape = Axon.get_output_shape(%Axon{output: id, nodes: nodes}, templates)
120120

121121
row = [
122-
"#{name} ( #{op_string} #{inspect(input_names)} )",
122+
"#{name} ( #{op_string} #{inspect(apply(structure, input_names))} )",
123123
"#{inspect({})}",
124124
"#{inspect(shape)}",
125125
render_options([]),
@@ -311,27 +311,6 @@ defmodule Axon.Display do
311311
end
312312
end
313313

314-
defp recur_axon_to_edges(
315-
%Axon.Node{id: id, op: :container, name: name_fn, parent: [parents]},
316-
nodes,
317-
templates,
318-
cache_counts_edgelist
319-
) do
320-
{node_inputs, {cache, op_counts, edgelist}} =
321-
deep_map_reduce(parents, cache_counts_edgelist, &axon_to_edges(&1, nodes, templates, &2))
322-
323-
name = name_fn.(:container, op_counts)
324-
node_shape = Axon.get_output_shape(%Axon{output: id, nodes: nodes}, templates)
325-
to_node = %{axon: :axon, id: id, op: :container, name: name, shape: node_shape}
326-
327-
new_edgelist =
328-
deep_reduce(node_inputs, edgelist, fn from_node, acc ->
329-
[{from_node, to_node} | acc]
330-
end)
331-
332-
{to_node, {cache, op_counts, new_edgelist}}
333-
end
334-
335314
defp recur_axon_to_edges(
336315
%Axon.Node{id: id, op_name: op, name: name_fn, parent: parents},
337316
nodes,

0 commit comments

Comments
 (0)