Skip to content

Commit 98e6e5f

Browse files
committed
Only tuples for multiple inputs to allow for 1D input
1 parent 5c78e13 commit 98e6e5f

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

src/layers/basic.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ x1 --> layer1 --> y1
544544
545545
This layer behaves differently based on input type:
546546
547-
1. Input `x` is a tuple/vector of length `N`. Then `layers` must be a tuple of length `N`. The computation is as follows:
547+
1. Input `x` is a tuple of length `N`. Then `layers` must be a tuple of length `N`. The computation is as follows:
548548
549549
```julia
550550
y = x[1]
@@ -583,14 +583,12 @@ end
583583

584584
function (m::PairwiseFusion)(x::T) where {T}
585585
nlayers = length(m.layers)
586-
if T <: Union{Tuple, Vector}
587-
getinput(i) = x[i]
586+
getinput(i) = T <: Tuple ? x[i] : x
587+
if T <: Tuple
588588
nx = length(x)
589589
if nx != nlayers
590590
throw(ArgumentError("PairwiseFusion with $nlayers layers takes $nlayers inputs, but got $nx inputs"))
591591
end
592-
else
593-
getinput(i) = x
594592
end
595593
outputs = [m.layers[1](getinput(1))]
596594
for i in 2:nlayers
@@ -644,7 +642,7 @@ end
644642
@functor Embedding
645643

646644
Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(out, in))
647-
645+
648646
(m::Embedding)(x::Integer) = m.weight[:, x]
649647
(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
650648
(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
@@ -653,7 +651,7 @@ function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T
653651
size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L"))
654652
return m(onecold(x))
655653
end
656-
654+
657655
function Base.show(io::IO, m::Embedding)
658656
print(io, "Embedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
659657
end

0 commit comments

Comments
 (0)