Skip to content

Commit 6a2d8de

Browse files
committed
PairwiseFusion layer
1 parent b6b3569 commit 6a2d8de

File tree

2 files changed

+88
-4
lines changed

2 files changed

+88
-4
lines changed

src/Flux.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ export gradient
1616
# Pirate error to catch a common mistake. (Internal function `base` because overloading `update!` is more likely to give ambiguities.)
1717
Optimisers.base(dx::Zygote.Grads) = error("Optimisers.jl cannot be used with Zygote.jl's implicit gradients, `Params` & `Grads`")
1818

19-
export Chain, Dense, Maxout, SkipConnection, Parallel,
19+
export Chain, Dense, Maxout, SkipConnection, Parallel, PairwiseFusion,
2020
RNN, LSTM, GRU, GRUv3,
2121
SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
2222
AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,

src/layers/basic.jl

Lines changed: 87 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,7 @@ end
487487
Parallel(connection, layers...) = Parallel(connection, layers)
488488
function Parallel(connection; kw...)
489489
layers = NamedTuple(kw)
490-
if :layers in Base.keys(layers) || :connection in Base.keys(layers)
490+
if :layers in keys(layers) || :connection in keys(layers)
491491
throw(ArgumentError("a Parallel layer cannot have a named sub-layer called `connection` or `layers`"))
492492
end
493493
isempty(layers) && return Parallel(connection, ())
@@ -510,16 +510,100 @@ end
510510
Base.getindex(m::Parallel, i) = m.layers[i]
511511
Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i])
512512
Base.getindex(m::Parallel{<:Any, <:NamedTuple}, i::AbstractVector) =
513-
Parallel(m.connection, NamedTuple{Base.keys(m)[i]}(Tuple(m.layers)[i]))
513+
Parallel(m.connection, NamedTuple{keys(m)[i]}(Tuple(m.layers)[i]))
514514

515-
Base.keys(m::Parallel) = Base.keys(getfield(m, :layers))
515+
Base.keys(m::Parallel) = keys(getfield(m, :layers))
516516

517517
function Base.show(io::IO, m::Parallel)
518518
print(io, "Parallel(", m.connection, ", ")
519519
_show_layers(io, m.layers)
520520
print(io, ")")
521521
end
522522

523+
"""
524+
PairwiseFusion(connection, layers...)
525+
526+
```
527+
x1 --> layer1 --> y1
528+
|
529+
|--> connection --> layer2 --> y2
530+
| |
531+
x2 |--> connection --> layer3 --> y3
532+
| |
533+
x3 |--> connection --> y4
534+
|
535+
x4
536+
```
537+
538+
## Arguments
539+
540+
- `connection`: Takes 2 inputs and combines them
541+
- `layers`: The layers whose outputs are combined
542+
543+
## Inputs
544+
545+
This layer behaves differently based on input type:
546+
547+
1. Input `x` is a tuple/vector of length `N`. Then `layers` must be a tuple of length `N`. The computation is as follows:
548+
549+
```julia
550+
y = x[1]
551+
for i in 1:N
552+
y = connection(x[i], layers[i](y))
553+
end
554+
```
555+
556+
2. Any other kind of input:
557+
558+
```julia
559+
y = x
560+
for i in 1:N
561+
y = connection(x, layers[i](y))
562+
end
563+
```
564+
565+
## Returns
566+
567+
`PairwiseFusion` returns a tuple of length `N` with the output of each fusion ((`y1`, `y2`, ..., `yN`) in the example above).
568+
"""
569+
struct PairwiseFusion{F, T<:Union{Tuple, NamedTuple}}
570+
connection::F
571+
layers::T
572+
end
573+
574+
PairwiseFusion(connection, layers...) = PairwiseFusion(connection, layers)
575+
function PairwiseFusion(connection; kw...)
576+
layers = NamedTuple(kw)
577+
if :layers in keys(layers) || :connection in keys(layers)
578+
throw(ArgumentError("a Parallel layer cannot have a named sub-layer called `connection` or `layers`"))
579+
end
580+
isempty(layers) && return Parallel(connection, ())
581+
return PairwiseFusion(connection, layers)
582+
end
583+
584+
function (m::PairwiseFusion)(x::T) where {T}
585+
getinput(i) = T <: Union{Tuple, Vector} ? x[i] : x
586+
nx = length(x)
587+
nlayers = length(m.layers)
588+
if nx != nlayers
589+
throw(ArgumentError("PairwiseFusion with $nlayers layers takes $nlayers inputs, but got $nx inputs"))
590+
end
591+
outputs = [m.layers[1](getinput(1))]
592+
for i in 2:nlayers
593+
push!(outputs, m.layers[i](m.connection(getinput(i), outputs[i - 1])))
594+
end
595+
return outputs
596+
end
597+
598+
@functor PairwiseFusion
599+
600+
Base.getindex(m::PairwiseFusion, i) = m.layers[i]
601+
Base.getindex(m::PairwiseFusion, i::AbstractVector) = PairwiseFusion(m.connection, m.layers[i])
602+
Base.getindex(m::PairwiseFusion{<:Any, <:NamedTuple}, i::AbstractVector) =
603+
PairwiseFusion(m.connection, NamedTuple{keys(m)[i]}(Tuple(m.layers)[i]))
604+
605+
Base.keys(m::PairwiseFusion) = keys(getfield(m, :layers))
606+
523607
"""
524608
Embedding(in => out; init=randn)
525609

0 commit comments

Comments
 (0)