|
67 | 67 |
|
68 | 68 | Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i])
|
69 | 69 | Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
|
70 |
| - Chain(NamedTuple{Base.keys(c)[i]}(Tuple(c.layers)[i])) |
| 70 | + Chain(NamedTuple{keys(c)[i]}(Tuple(c.layers)[i])) |
71 | 71 | function Base.show(io::IO, c::Chain)
|
72 | 72 | print(io, "Chain(")
|
73 | 73 | _show_layers(io, c.layers)
|
@@ -566,35 +566,37 @@ end
|
566 | 566 |
|
567 | 567 | `PairwiseFusion` returns a tuple of length `N` with the output of each fusion ((`y1`, `y2`, ..., `yN`) in the example above).
|
568 | 568 | """
|
569 |
| -struct PairwiseFusion{F, T<:Union{Tuple, NamedTuple}} |
| 569 | +struct PairwiseFusion{F, T <: NamedTuple} |
570 | 570 | connection::F
|
571 | 571 | layers::T
|
572 | 572 | end
|
573 | 573 |
|
574 |
| -PairwiseFusion(connection, layers...) = PairwiseFusion(connection, layers) |
575 |
| -function PairwiseFusion(connection; kw...) |
576 |
| - layers = NamedTuple(kw) |
577 |
| - if :layers in keys(layers) || :connection in keys(layers) |
578 |
| - throw(ArgumentError("a Parallel layer cannot have a named sub-layer called `connection` or `layers`")) |
579 |
| - end |
580 |
| - isempty(layers) && return Parallel(connection, ()) |
581 |
| - return PairwiseFusion(connection, layers) |
| 574 | +function PairwiseFusion(connection, layers...) |
| 575 | + names = ntuple(i -> Symbol("layer_$i"), length(layers)) |
| 576 | + return PairwiseFusion(connection, NamedTuple{names}(layers)) |
582 | 577 | end
|
583 | 578 |
|
584 | 579 | function (m::PairwiseFusion)(x::T) where {T}
|
585 |
| - nlayers = length(m.layers) |
586 |
| - getinput(i) = T <: Tuple ? x[i] : x |
587 |
| - if T <: Tuple |
588 |
| - nx = length(x) |
589 |
| - if nx != nlayers |
590 |
| - throw(ArgumentError("PairwiseFusion with $nlayers layers takes $nlayers inputs, but got $nx inputs")) |
591 |
| - end |
592 |
| - end |
593 |
| - outputs = [m.layers[1](getinput(1))] |
594 |
| - for i in 2:nlayers |
595 |
| - push!(outputs, m.layers[i](m.connection(getinput(i), outputs[i - 1]))) |
| 580 | + lx = length(x) |
| 581 | + N = length(m.layers) |
| 582 | + if T <: Tuple && lx != N |
| 583 | + throw(ArgumentError("PairwiseFusion with $N sub-layers can take one input or $N inputs, but got $lx inputs")) |
596 | 584 | end
|
597 |
| - return outputs |
| 585 | + applypairwisefusion(m.layers, m.connection, x) |
| 586 | +end |
| 587 | + |
| 588 | +@generated function applypairwisefusion(layers::NamedTuple{names}, connection, x::T) where {names, T} |
| 589 | + N = length(names) |
| 590 | + y_symbols = [gensym() for _ in 1:(N + 1)] |
| 591 | + getinput(i) = T <: Tuple ? :(x[$i]) : :x |
| 592 | + calls = [:($(y_symbols[N + 1]) = $(getinput(1)))] |
| 593 | + append!(calls, |
| 594 | + [:($(y_symbols[i]) = layers[$i]($(y_symbols[N + 1])); |
| 595 | + $(y_symbols[N + 1]) = connection($(y_symbols[i]), $(getinput(i + 1)))) |
| 596 | + for i in 1:N - 1]) |
| 597 | + push!(calls, :($(y_symbols[N]) = layers[$N]($(y_symbols[N + 1])))) |
| 598 | + push!(calls, :(return $(y_symbols[N]))) |
| 599 | + return Expr(:block, calls...) |
598 | 600 | end
|
599 | 601 |
|
600 | 602 | @functor PairwiseFusion
|
|
0 commit comments