Skip to content

Commit d0f0a29

Browse files
committed
Allow N inputs
Add tests with `vcat`
1 parent 78157e3 commit d0f0a29

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

src/layers/basic.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,7 @@ end
538538
539539
This layer behaves differently based on input type:
540540
541-
1. If input `x` is a tuple of length `N`, matching the number of `layers`,
541+
1. If input `x` is a tuple of length N (or the input is `xs` with N `x`'s), matching the number of `layers`,
542542
then each layer receives a new input `x[i]` combined with the previous output `y[i-1]` using `connection`.
543543
Thus `(y1, y2, y3) = PairwiseFusion(connection, layer1, layer2, layer3)((x1, x2, x3))`
544544
may be drawn as:
@@ -567,7 +567,7 @@ end
567567
568568
## Returns
569569
570-
A tuple of length `N` with the output of each fusion ((`y1`, `y2`, ..., `yN`) in the example above).
570+
A tuple of length N with the output of each fusion ((`y1`, `y2`, ..., `yN`) in the example above).
571571
"""
572572
struct PairwiseFusion{F, T<:Union{Tuple, NamedTuple}}
573573
connection::F
@@ -597,6 +597,7 @@ function (m::PairwiseFusion)(x::T) where {T}
597597
_pairwise_check(x, m.layers, T)
598598
applypairwisefusion(m.layers, m.connection, x)
599599
end
600+
(m::PairwiseFusion)(xs...) = m(xs)
600601

601602
@generated function applypairwisefusion(layers::Tuple{Vararg{<:Any,N}}, connection, x::T) where {N, T}
602603
y_symbols = [gensym() for _ in 1:(N + 1)]

test/layers/basic.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,4 +365,7 @@ end
365365
@test length(y) == 2
366366
@test size(y[1]) == (10, 10)
367367
@test size(y[2]) == (1, 10)
368+
369+
@test PairwiseFusion(vcat, x->x.+1, x->x.+2, x->x.^3)(2, 10, 20) == (3, [5, 12], [125, 1728, 8000])
370+
@test PairwiseFusion(vcat, x->x.+1, x->x.+2, x->x.^3)(7) == (8, [10, 9], [1000, 729, 343])
368371
end

0 commit comments

Comments
 (0)