Skip to content

Commit e842e7b

Browse files
committed
Refine layer a bit; add one more test
1 parent 1180a62 commit e842e7b

File tree

2 files changed

+22
-8
lines changed

2 files changed

+22
-8
lines changed

src/layers/basic.jl

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -564,7 +564,7 @@ end
564564
565565
## Returns
566566
567-
`PairwiseFusion` returns a tuple of length `N` with the output of each fusion ((`y1`, `y2`, ..., `yN`) in the example above).
567+
A tuple of length `N` with the output of each fusion ((`y1`, `y2`, ..., `yN`) in the example above).
568568
"""
569569
struct PairwiseFusion{F, T <: NamedTuple}
570570
connection::F
@@ -576,12 +576,17 @@ function PairwiseFusion(connection, layers...)
576576
return PairwiseFusion(connection, NamedTuple{names}(layers))
577577
end
578578

579-
function (m::PairwiseFusion)(x::T) where {T}
580-
lx = length(x)
581-
N = length(m.layers)
579+
function _pairwise_check(lx, N, T)
582580
if T <: Tuple && lx != N
583581
throw(ArgumentError("PairwiseFusion with $N sub-layers can take one input or $N inputs, but got $lx inputs"))
584582
end
583+
end
584+
ChainRulesCore.@non_differentiable _pairwise_check(lx, N, T)
585+
586+
function (m::PairwiseFusion)(x::T) where {T}
587+
lx = length(x)
588+
N = length(m.layers)
589+
_pairwise_check(lx, N, T)
585590
applypairwisefusion(m.layers, m.connection, x)
586591
end
587592

@@ -590,10 +595,12 @@ end
590595
y_symbols = [gensym() for _ in 1:(N + 1)]
591596
getinput(i) = T <: Tuple ? :(x[$i]) : :x
592597
calls = [:($(y_symbols[N + 1]) = $(getinput(1)))]
593-
append!(calls,
594-
[:($(y_symbols[i]) = layers[$i]($(y_symbols[N + 1]));
595-
$(y_symbols[N + 1]) = connection($(y_symbols[i]), $(getinput(i + 1))))
596-
for i in 1:N - 1])
598+
for i in 1:N - 1
599+
push!(calls, quote
600+
$(y_symbols[i]) = layers[$i]($(y_symbols[N + 1]))
601+
$(y_symbols[N + 1]) = connection($(y_symbols[i]), $(getinput(i + 1)))
602+
end)
603+
end
597604
push!(calls, :($(y_symbols[N]) = layers[$N]($(y_symbols[N + 1]))))
598605
push!(calls, :(return tuple($(Tuple(y_symbols[1:N])...))))
599606
return Expr(:block, calls...)

test/layers/basic.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,4 +358,11 @@ end
358358
@test length(y) == 2
359359
@test size(y[1]) == (30, 10)
360360
@test size(y[2]) == (10, 10)
361+
362+
x = rand(1, 10)
363+
layer = PairwiseFusion(.+, Dense(1, 10), Dense(10, 1))
364+
y = layer(x)
365+
@test length(y) == 2
366+
@test size(y[1]) == (10, 10)
367+
@test size(y[2]) == (1, 10)
361368
end

0 commit comments

Comments
 (0)