Skip to content

Commit 5c78e13

Browse files
committed
Fix for non-tuple/vector inputs
1 parent 6a2d8de commit 5c78e13

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

src/layers/basic.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -582,11 +582,15 @@ function PairwiseFusion(connection; kw...)
582582
end
583583

584584
function (m::PairwiseFusion)(x::T) where {T}
585-
getinput(i) = T <: Union{Tuple, Vector} ? x[i] : x
586-
nx = length(x)
587585
nlayers = length(m.layers)
588-
if nx != nlayers
589-
throw(ArgumentError("PairwiseFusion with $nlayers layers takes $nlayers inputs, but got $nx inputs"))
586+
if T <: Union{Tuple, Vector}
587+
getinput(i) = x[i]
588+
nx = length(x)
589+
if nx != nlayers
590+
throw(ArgumentError("PairwiseFusion with $nlayers layers takes $nlayers inputs, but got $nx inputs"))
591+
end
592+
else
593+
getinput(i) = x
590594
end
591595
outputs = [m.layers[1](getinput(1))]
592596
for i in 2:nlayers

0 commit comments

Comments
 (0)