@@ -109,7 +109,7 @@ defmodule Axon.Display do
109
109
% Axon.Node {
110
110
id: id ,
111
111
op_name: :container ,
112
- parent: [ parents ] ,
112
+ parent: [ _ | _ ] = parents ,
113
113
name: name_fn
114
114
} ,
115
115
nodes ,
@@ -221,6 +221,13 @@ defmodule Axon.Display do
221
221
"#{ type } #{ shape } "
222
222
end
223
223
224
+ defp render_output_shape ( shapes ) when is_tuple ( shapes ) do
225
+ shapes
226
+ |> Tuple . to_list ( )
227
+ |> Enum . map ( & render_output_shape ( & 1 ) )
228
+ |> Enum . join ( ", " )
229
+ end
230
+
224
231
defp type_str ( { type , size } ) , do: "#{ Atom . to_string ( type ) } #{ size } "
225
232
226
233
defp render_options ( opts ) do
@@ -347,7 +354,9 @@ defmodule Axon.Display do
347
354
348
355
name = name_fn . ( op , op_counts )
349
356
node_shape = Axon . get_output_shape ( % Axon { output: id , nodes: nodes } , templates )
350
- to_node = % { axon: :axon , id: id , op: op , name: name , shape: node_shape }
357
+ shape = expand_output_shape ( node_shape )
358
+
359
+ to_node = % { axon: :axon , id: id , op: op , name: name , shape: shape }
351
360
352
361
new_edgelist =
353
362
Enum . reduce ( node_inputs , edgelist , fn from_node , acc ->
@@ -357,6 +366,15 @@ defmodule Axon.Display do
357
366
{ to_node , { cache , op_counts , new_edgelist } }
358
367
end
359
368
369
+ defp expand_output_shape ( % Nx.Tensor { } = tensor ) , do: Nx . shape ( tensor )
370
+
371
+ defp expand_output_shape ( shapes ) when is_tuple ( shapes ) do
372
+ shapes
373
+ |> Tuple . to_list ( )
374
+ |> Enum . map ( & expand_output_shape / 1 )
375
+ |> List . to_tuple ( )
376
+ end
377
+
360
378
defp generate_mermaid_node_entry ( % { id: id , op: :input , name: name , shape: shape } ) do
361
379
~s' #{ id } [/"#{ name } (:input) #{ inspect ( shape ) } "/]'
362
380
end
0 commit comments