Skip to content

Commit 15ace5c

Browse files
authored
fix(display): handle multi inputs and outputs (#605)
1 parent c61077c commit 15ace5c

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

lib/axon/display.ex

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ defmodule Axon.Display do
109109
%Axon.Node{
110110
id: id,
111111
op_name: :container,
112-
parent: [parents],
112+
parent: [_ | _] = parents,
113113
name: name_fn
114114
},
115115
nodes,
@@ -221,6 +221,13 @@ defmodule Axon.Display do
221221
"#{type}#{shape}"
222222
end
223223

224+
defp render_output_shape(shapes) when is_tuple(shapes) do
225+
shapes
226+
|> Tuple.to_list()
227+
|> Enum.map(&render_output_shape(&1))
228+
|> Enum.join(", ")
229+
end
230+
224231
defp type_str({type, size}), do: "#{Atom.to_string(type)}#{size}"
225232

226233
defp render_options(opts) do
@@ -347,7 +354,9 @@ defmodule Axon.Display do
347354

348355
name = name_fn.(op, op_counts)
349356
node_shape = Axon.get_output_shape(%Axon{output: id, nodes: nodes}, templates)
350-
to_node = %{axon: :axon, id: id, op: op, name: name, shape: node_shape}
357+
shape = expand_output_shape(node_shape)
358+
359+
to_node = %{axon: :axon, id: id, op: op, name: name, shape: shape}
351360

352361
new_edgelist =
353362
Enum.reduce(node_inputs, edgelist, fn from_node, acc ->
@@ -357,6 +366,15 @@ defmodule Axon.Display do
357366
{to_node, {cache, op_counts, new_edgelist}}
358367
end
359368

369+
defp expand_output_shape(%Nx.Tensor{} = tensor), do: Nx.shape(tensor)
370+
371+
defp expand_output_shape(shapes) when is_tuple(shapes) do
372+
shapes
373+
|> Tuple.to_list()
374+
|> Enum.map(&expand_output_shape/1)
375+
|> List.to_tuple()
376+
end
377+
360378
defp generate_mermaid_node_entry(%{id: id, op: :input, name: name, shape: shape}) do
361379
~s'#{id}[/"#{name} (:input) #{inspect(shape)}"/]'
362380
end

0 commit comments

Comments
 (0)