Skip to content

Commit 0b01b77

Browse files
authored
Merge pull request #1983 from theabhirath/pairwise-fusion-2
`PairwiseFusion` layer, take 2
2 parents f86b356 + d0f0a29 commit 0b01b77

File tree

5 files changed

+147
-14
lines changed

5 files changed

+147
-14
lines changed

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# Flux Release Notes
22

3+
## v0.13.4
4+
* Added [`PairwiseFusion` layer](https://github.com/FluxML/Flux.jl/pull/1983)
5+
36
## v0.13
47
* After a deprecations cycle, the datasets in `Flux.Data` have
58
been removed in favour of MLDatasets.jl.

src/Flux.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ export gradient
1616
# Pirate error to catch a common mistake. (Internal function `base` because overloading `update!` is more likely to give ambiguities.)
1717
Optimisers.base(dx::Zygote.Grads) = error("Optimisers.jl cannot be used with Zygote.jl's implicit gradients, `Params` & `Grads`")
1818

19-
export Chain, Dense, Maxout, SkipConnection, Parallel,
19+
export Chain, Dense, Maxout, SkipConnection, Parallel, PairwiseFusion,
2020
RNN, LSTM, GRU, GRUv3,
2121
SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
2222
AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,

src/layers/basic.jl

Lines changed: 121 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ end
3838

3939
Chain(xs...) = Chain(xs)
4040
function Chain(; kw...)
41-
:layers in Base.keys(kw) && throw(ArgumentError("a Chain cannot have a named layer called `layers`"))
41+
:layers in keys(kw) && throw(ArgumentError("a Chain cannot have a named layer called `layers`"))
4242
isempty(kw) && return Chain(())
4343
Chain(values(kw))
4444
end
@@ -67,7 +67,7 @@ end
6767

6868
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i])
6969
Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
70-
Chain(NamedTuple{Base.keys(c)[i]}(Tuple(c.layers)[i]))
70+
Chain(NamedTuple{keys(c)[i]}(Tuple(c.layers)[i]))
7171
function Base.show(io::IO, c::Chain)
7272
print(io, "Chain(")
7373
_show_layers(io, c.layers)
@@ -487,7 +487,7 @@ end
487487
Parallel(connection, layers...) = Parallel(connection, layers)
488488
function Parallel(connection; kw...)
489489
layers = NamedTuple(kw)
490-
if :layers in Base.keys(layers) || :connection in Base.keys(layers)
490+
if :layers in keys(layers) || :connection in keys(layers)
491491
throw(ArgumentError("a Parallel layer cannot have a named sub-layer called `connection` or `layers`"))
492492
end
493493
isempty(layers) && return Parallel(connection, ())
@@ -498,28 +498,138 @@ end
498498

499499
(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...)
500500
(m::Parallel)(xs::Tuple) = m(xs...)
501-
function (m::Parallel)(xs...)
502-
nl = length(m.layers)
503-
nx = length(xs)
504-
if nl != nx
501+
502+
function _parallel_check(layers, xs)
503+
nl = length(layers)
504+
nx = length(xs)
505+
if (nl != nx)
505506
throw(ArgumentError("Parallel with $nl sub-layers can take one input or $nl inputs, but got $nx inputs"))
506507
end
508+
end
509+
ChainRulesCore.@non_differentiable _parallel_check(nl, nx)
510+
511+
function (m::Parallel)(xs...)
512+
_parallel_check(m.layers, xs)
507513
m.connection(map(|>, xs, Tuple(m.layers))...)
508514
end
509515

510516
Base.getindex(m::Parallel, i) = m.layers[i]
511517
Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i])
512518
Base.getindex(m::Parallel{<:Any, <:NamedTuple}, i::AbstractVector) =
513-
Parallel(m.connection, NamedTuple{Base.keys(m)[i]}(Tuple(m.layers)[i]))
519+
Parallel(m.connection, NamedTuple{keys(m)[i]}(Tuple(m.layers)[i]))
514520

515-
Base.keys(m::Parallel) = Base.keys(getfield(m, :layers))
521+
Base.keys(m::Parallel) = keys(getfield(m, :layers))
516522

517523
function Base.show(io::IO, m::Parallel)
518524
print(io, "Parallel(", m.connection, ", ")
519525
_show_layers(io, m.layers)
520526
print(io, ")")
521527
end
522528

529+
"""
530+
PairwiseFusion(connection, layers...)
531+
532+
## Arguments
533+
534+
- `connection`: A function taking 2 inputs and combining them into a single output
535+
- `layers`: The layers whose outputs are combined
536+
537+
## Inputs
538+
539+
This layer behaves differently based on input type:
540+
541+
1. If input `x` is a tuple of length N (or the input is `xs` with N `x`'s), matching the number of `layers`,
542+
then each layer receives a new input `x[i]` combined with the previous output `y[i-1]` using `connection`.
543+
Thus `(y1, y2, y3) = PairwiseFusion(connection, layer1, layer2, layer3)((x1, x2, x3))`
544+
may be drawn as:
545+
```
546+
x1 → layer1 → y1 ↘
547+
connection → layer2 → y2 ↘
548+
x2 ↗ connection → layer3 → y3
549+
x3 ↗
550+
```
551+
... or written as:
552+
```julia
553+
y1 = layer1(x1)
554+
y2 = layer2(connection(x2, y1))
555+
y3 = layer3(connection(x3, y2))
556+
```
557+
558+
2. With just one input, each layer receives the same `x` combined with the previous output.
559+
Thus `y = PairwiseFusion(connection, layers...)(x)` obeys:
560+
561+
```julia
562+
y[1] == layers[1](x)
563+
for i in 2:length(layers)
564+
y[i] == connection(x, layers[i](y[i-1]))
565+
end
566+
```
567+
568+
## Returns
569+
570+
A tuple of length N with the output of each fusion ((`y1`, `y2`, ..., `yN`) in the example above).
571+
"""
572+
struct PairwiseFusion{F, T<:Union{Tuple, NamedTuple}}
573+
connection::F
574+
layers::T
575+
end
576+
577+
PairwiseFusion(connection, layers...) = PairwiseFusion(connection, layers)
578+
function PairwiseFusion(connection; kw...)
579+
layers = NamedTuple(kw)
580+
if :layers in keys(layers) || :connection in keys(layers)
581+
throw(ArgumentError("a PairwiseFusion layer cannot have a named sub-layer called `connection` or `layers`"))
582+
end
583+
isempty(layers) && return PairwiseFusion(connection, ())
584+
PairwiseFusion(connection, layers)
585+
end
586+
587+
function _pairwise_check(x, layers, T)
588+
lx = length(x)
589+
N = length(layers)
590+
if T <: Tuple && lx != N
591+
throw(ArgumentError("PairwiseFusion with $N sub-layers can take one input or $N inputs, but got $lx inputs"))
592+
end
593+
end
594+
ChainRulesCore.@non_differentiable _pairwise_check(lx, N, T)
595+
596+
function (m::PairwiseFusion)(x::T) where {T}
597+
_pairwise_check(x, m.layers, T)
598+
applypairwisefusion(m.layers, m.connection, x)
599+
end
600+
(m::PairwiseFusion)(xs...) = m(xs)
601+
602+
@generated function applypairwisefusion(layers::Tuple{Vararg{<:Any,N}}, connection, x::T) where {N, T}
603+
y_symbols = [gensym() for _ in 1:(N + 1)]
604+
getinput(i) = T <: Tuple ? :(x[$i]) : :x
605+
calls = [:($(y_symbols[N + 1]) = $(getinput(1)))]
606+
for i in 1:N - 1
607+
push!(calls, quote
608+
$(y_symbols[i]) = layers[$i]($(y_symbols[N + 1]))
609+
$(y_symbols[N + 1]) = connection($(y_symbols[i]), $(getinput(i + 1)))
610+
end)
611+
end
612+
push!(calls, :($(y_symbols[N]) = layers[$N]($(y_symbols[N + 1]))))
613+
push!(calls, :(return tuple($(Tuple(y_symbols[1:N])...))))
614+
return Expr(:block, calls...)
615+
end
616+
applypairwisefusion(layers::NamedTuple, connection, x) = applypairwisefusion(Tuple(layers), connection, x)
617+
618+
@functor PairwiseFusion
619+
620+
Base.getindex(m::PairwiseFusion, i) = m.layers[i]
621+
Base.getindex(m::PairwiseFusion, i::AbstractVector) = PairwiseFusion(m.connection, m.layers[i])
622+
Base.getindex(m::PairwiseFusion{<:Any, <:NamedTuple}, i::AbstractVector) =
623+
PairwiseFusion(m.connection, NamedTuple{keys(m)[i]}(Tuple(m.layers)[i]))
624+
625+
Base.keys(m::PairwiseFusion) = keys(getfield(m, :layers))
626+
627+
function Base.show(io::IO, m::PairwiseFusion)
628+
print(io, "PairwiseFusion(", m.connection, ", ")
629+
_show_layers(io, m.layers)
630+
print(io, ")")
631+
end
632+
523633
"""
524634
Embedding(in => out; init=randn)
525635
@@ -556,7 +666,7 @@ end
556666
@functor Embedding
557667

558668
Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(out, in))
559-
669+
560670
(m::Embedding)(x::Integer) = m.weight[:, x]
561671
(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
562672
(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)
@@ -565,7 +675,7 @@ function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T
565675
size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L"))
566676
return m(onecold(x))
567677
end
568-
678+
569679
function Base.show(io::IO, m::Embedding)
570680
print(io, "Embedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")")
571681
end

src/layers/show.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

22
for T in [
3-
:Chain, :Parallel, :SkipConnection, :Recur, :Maxout # container types
3+
:Chain, :Parallel, :SkipConnection, :Recur, :Maxout, :PairwiseFusion # container types
44
]
55
@eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
66
if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL
@@ -25,7 +25,7 @@ function _big_show(io::IO, obj, indent::Int=0, name=nothing)
2525
for k in Base.keys(obj)
2626
_big_show(io, obj[k], indent+2, k)
2727
end
28-
elseif obj isa Parallel{<:Any, <:NamedTuple}
28+
elseif obj isa Parallel{<:Any, <:NamedTuple} || obj isa PairwiseFusion{<:Any, <:NamedTuple}
2929
_big_show(io, obj.connection, indent+2)
3030
for k in Base.keys(obj)
3131
_big_show(io, obj[k], indent+2, k)
@@ -53,6 +53,7 @@ _show_children(x) = trainable(x) # except for layers which hide their Tuple:
5353
_show_children(c::Chain) = c.layers
5454
_show_children(m::Maxout) = m.layers
5555
_show_children(p::Parallel) = (p.connection, p.layers...)
56+
_show_children(f::PairwiseFusion) = (f.connection, f.layers...)
5657

5758
for T in [
5859
:Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding,

test/layers/basic.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,3 +350,22 @@ end
350350
@test Flux.destructure(m1)[2](z1)[1].weight Flux.destructure(m1v)[2](z1)[1].weight
351351
# Note that Flux.destructure(m1v)[2](z) has a Chain{Tuple}, as does m1v[1:2]
352352
end
353+
354+
@testset "PairwiseFusion" begin
355+
x = (rand(1, 10), rand(30, 10))
356+
layer = PairwiseFusion(+, Dense(1, 30), Dense(30, 10))
357+
y = layer(x)
358+
@test length(y) == 2
359+
@test size(y[1]) == (30, 10)
360+
@test size(y[2]) == (10, 10)
361+
362+
x = rand(1, 10)
363+
layer = PairwiseFusion(.+, Dense(1, 10), Dense(10, 1))
364+
y = layer(x)
365+
@test length(y) == 2
366+
@test size(y[1]) == (10, 10)
367+
@test size(y[2]) == (1, 10)
368+
369+
@test PairwiseFusion(vcat, x->x.+1, x->x.+2, x->x.^3)(2, 10, 20) == (3, [5, 12], [125, 1728, 8000])
370+
@test PairwiseFusion(vcat, x->x.+1, x->x.+2, x->x.^3)(7) == (8, [10, 9], [1000, 729, 343])
371+
end

0 commit comments

Comments
 (0)