@@ -544,7 +544,7 @@ x1 --> layer1 --> y1
 
 This layer behaves differently based on input type:
 
-1. Input `x` is a tuple/vector of length `N`. Then `layers` must be a tuple of length `N`. The computation is as follows:
+1. Input `x` is a tuple of length `N`. Then `layers` must be a tuple of length `N`. The computation is as follows:
 
 ```julia
 y = x[1]
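To make the tuple branch above concrete, here is a small unrolled sketch; the layer sizes, the `+` connection, and the per-stage output shapes are illustrative assumptions, not taken from this diff:

```julia
using Flux

# Hypothetical two-stage fusion with `+` as the connection.
# The first layer sees x[1]; each later layer sees the previous output
# combined with the next input through the connection.
m = PairwiseFusion(+, Dense(1 => 10), Dense(10 => 4))
x = (rand(Float32, 1, 8), rand(Float32, 10, 8))   # tuple input, length == number of layers
y = m(x)
size(y[1])   # (10, 8): Dense(1 => 10) applied to x[1]
size(y[2])   # (4, 8):  Dense(10 => 4) applied to y[1] + x[2]
```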
@@ -583,14 +583,12 @@ end
 
 function (m::PairwiseFusion)(x::T) where {T}
   nlayers = length(m.layers)
-  if T <: Union{Tuple, Vector}
-    getinput(i) = x[i]
+  getinput(i) = T <: Tuple ? x[i] : x
+  if T <: Tuple
     nx = length(x)
     if nx != nlayers
       throw(ArgumentError("PairwiseFusion with $nlayers layers takes $nlayers inputs, but got $nx inputs"))
     end
-  else
-    getinput(i) = x
   end
   outputs = [m.layers[1](getinput(1))]
   for i in 2:nlayers
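The refactored `getinput` closure selects `x[i]` for tuple inputs and reuses the same `x` for every stage otherwise, matching the docstring edit in the first hunk (vectors are no longer split element-wise). A rough sketch of the single-input case, with made-up layer sizes and `vcat` as the connection:

```julia
using Flux

# Hypothetical single-array input: every stage is fed the same `x`,
# fused with the running output by the connection.
m = PairwiseFusion(vcat, Dense(5 => 3), Dense(8 => 2))
x = rand(Float32, 5, 16)
y = m(x)
size(y[1])   # (3, 16): Dense(5 => 3)(x)
size(y[2])   # (2, 16): Dense(8 => 2)(vcat(y[1], x))
```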
@@ -644,7 +642,7 @@
 @functor Embedding
 
 Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(out, in))
-
+
 (m::Embedding)(x::Integer) = m.weight[:, x]
 (m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
 (m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
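A short usage sketch of the `Embedding` constructor and call methods above; the vocabulary size and embedding dimension are made up:

```julia
using Flux

emb = Embedding(26 => 4)   # weight is 4×26, initialised with randn32
emb(3)                     # column 3 of the weight matrix: a length-4 vector
emb([7, 7, 9])             # 4×3 matrix, gathered column by column
emb([1 2; 3 4])            # 4×2×2 array: vec, gather, then reshape as in the last method
```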
@@ -653,7 +651,7 @@ function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T,L}
   size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L"))
   return m(onecold(x))
 end
-
+
 function Base.show(io::IO, m::Embedding)
   print(io, "Embedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
 end
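And a minimal sketch of the one-hot path from the last hunk, assuming `onehotbatch`/`onecold` from Flux's one-hot utilities:

```julia
using Flux

emb = Embedding(26 => 4)
oh = onehotbatch([3, 1, 3], 1:26)   # 26×3 one-hot matrix, so L == 26
emb(oh)                             # 4×3: passes the size check, then falls back to emb(onecold(oh))
```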