Skip to content

Commit 617ae2a

Browse files
committed
Return all inputs, not just the final one
Add pretty printing
1 parent 96a6448 commit 617ae2a

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

src/layers/basic.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -595,19 +595,24 @@ end
595595
$(y_symbols[N + 1]) = connection($(y_symbols[i]), $(getinput(i + 1))))
596596
for i in 1:N - 1])
597597
push!(calls, :($(y_symbols[N]) = layers[$N]($(y_symbols[N + 1]))))
598-
push!(calls, :(return $(y_symbols[N])))
598+
push!(calls, :(return tuple($(Tuple(y_symbols[1:N])...))))
599599
return Expr(:block, calls...)
600600
end
601601

602602
@functor PairwiseFusion
603603

604604
Base.getindex(m::PairwiseFusion, i) = m.layers[i]
605-
Base.getindex(m::PairwiseFusion, i::AbstractVector) = PairwiseFusion(m.connection, m.layers[i])
606605
Base.getindex(m::PairwiseFusion{<:Any, <:NamedTuple}, i::AbstractVector) =
607606
PairwiseFusion(m.connection, NamedTuple{keys(m)[i]}(Tuple(m.layers)[i]))
608607

609608
Base.keys(m::PairwiseFusion) = keys(getfield(m, :layers))
610609

610+
function Base.show(io::IO, m::PairwiseFusion)
611+
print(io, "PairwiseFusion(", m.connection, ", ")
612+
_show_layers(io, m.layers)
613+
print(io, ")")
614+
end
615+
611616
"""
612617
Embedding(in => out; init=randn)
613618

src/layers/show.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

22
for T in [
3-
:Chain, :Parallel, :SkipConnection, :Recur, :Maxout # container types
3+
:Chain, :Parallel, :SkipConnection, :Recur, :Maxout, :PairwiseFusion # container types
44
]
55
@eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
66
if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL
@@ -53,6 +53,7 @@ _show_children(x) = trainable(x) # except for layers which hide their Tuple:
5353
_show_children(c::Chain) = c.layers
5454
_show_children(m::Maxout) = m.layers
5555
_show_children(p::Parallel) = (p.connection, p.layers...)
56+
_show_children(f::PairwiseFusion) = (f.connection, f.layers...)
5657

5758
for T in [
5859
:Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding,

0 commit comments

Comments
 (0)