Skip to content

Commit 96a6448

Browse files
committed
Use generated function
1 parent 98e6e5f commit 96a6448

File tree

1 file changed

+24
-22
lines changed

1 file changed

+24
-22
lines changed

src/layers/basic.jl

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ end
6767

6868
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i])
6969
Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
70-
Chain(NamedTuple{Base.keys(c)[i]}(Tuple(c.layers)[i]))
70+
Chain(NamedTuple{keys(c)[i]}(Tuple(c.layers)[i]))
7171
function Base.show(io::IO, c::Chain)
7272
print(io, "Chain(")
7373
_show_layers(io, c.layers)
@@ -566,35 +566,37 @@ end
566566
567567
`PairwiseFusion` returns a tuple of length `N` with the output of each fusion ((`y1`, `y2`, ..., `yN`) in the example above).
568568
"""
569-
struct PairwiseFusion{F, T<:Union{Tuple, NamedTuple}}
569+
struct PairwiseFusion{F, T <: NamedTuple}
570570
connection::F
571571
layers::T
572572
end
573573

574-
PairwiseFusion(connection, layers...) = PairwiseFusion(connection, layers)
575-
function PairwiseFusion(connection; kw...)
576-
layers = NamedTuple(kw)
577-
if :layers in keys(layers) || :connection in keys(layers)
578-
throw(ArgumentError("a Parallel layer cannot have a named sub-layer called `connection` or `layers`"))
579-
end
580-
isempty(layers) && return Parallel(connection, ())
581-
return PairwiseFusion(connection, layers)
574+
function PairwiseFusion(connection, layers...)
575+
names = ntuple(i -> Symbol("layer_$i"), length(layers))
576+
return PairwiseFusion(connection, NamedTuple{names}(layers))
582577
end
583578

584579
function (m::PairwiseFusion)(x::T) where {T}
585-
nlayers = length(m.layers)
586-
getinput(i) = T <: Tuple ? x[i] : x
587-
if T <: Tuple
588-
nx = length(x)
589-
if nx != nlayers
590-
throw(ArgumentError("PairwiseFusion with $nlayers layers takes $nlayers inputs, but got $nx inputs"))
591-
end
592-
end
593-
outputs = [m.layers[1](getinput(1))]
594-
for i in 2:nlayers
595-
push!(outputs, m.layers[i](m.connection(getinput(i), outputs[i - 1])))
580+
lx = length(x)
581+
N = length(m.layers)
582+
if T <: Tuple && lx != N
583+
throw(ArgumentError("PairwiseFusion with $N sub-layers can take one input or $N inputs, but got $lx inputs"))
596584
end
597-
return outputs
585+
applypairwisefusion(m.layers, m.connection, x)
586+
end
587+
588+
@generated function applypairwisefusion(layers::NamedTuple{names}, connection, x::T) where {names, T}
589+
N = length(names)
590+
y_symbols = [gensym() for _ in 1:(N + 1)]
591+
getinput(i) = T <: Tuple ? :(x[$i]) : :x
592+
calls = [:($(y_symbols[N + 1]) = $(getinput(1)))]
593+
append!(calls,
594+
[:($(y_symbols[i]) = layers[$i]($(y_symbols[N + 1]));
595+
$(y_symbols[N + 1]) = connection($(y_symbols[i]), $(getinput(i + 1))))
596+
for i in 1:N - 1])
597+
push!(calls, :($(y_symbols[N]) = layers[$N]($(y_symbols[N + 1]))))
598+
push!(calls, :(return $(y_symbols[N])))
599+
return Expr(:block, calls...)
598600
end
599601

600602
@functor PairwiseFusion

0 commit comments

Comments
 (0)