Skip to content

Commit e457ae5

Browse files
fix tests
1 parent 0273b25 commit e457ae5

File tree

3 files changed

+6
-10
lines changed

3 files changed

+6
-10
lines changed

src/layers/basic.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,9 @@ applylayer(l, g::GNNGraph, x) = l(x)
6464
applylayer(l::GNNLayer, g::GNNGraph, x) = l(g, x)
6565

6666
# Handle Flux.Parallel
67-
applylayer(l::Parallel, g::GNNGraph, x::AbstractArray) = mapreduce(f -> applylayer(l, g, x), l.connection, l.layers)
68-
applylayer(l::Parallel, g::GNNGraph, xs::Vararg{<:AbstractArray}) = mapreduce((f, x) -> applylayer(l, g, x), l.connection, l.layers, xs)
67+
applylayer(l::Parallel, g::GNNGraph, x::AbstractArray) = mapreduce(f -> applylayer(f, g, x), l.connection, l.layers)
68+
applylayer(l::Parallel, g::GNNGraph, xs::Vararg{<:AbstractArray}) = mapreduce((f, x) -> applylayer(f, g, x), l.connection, l.layers, xs)
6969
applylayer(l::Parallel, g::GNNGraph, xs::Tuple) = applylayer(l, g, xs...)
70-
applylayer(l::Parallel, g::GNNGraph, x::AbstractArray) = mapreduce(f -> applylayer(l, g, x), l.connection, l.layers)
71-
applylayer(l::Parallel, g::GNNGraph, xs::Vararg{<:AbstractArray}) = mapreduce((f, x) -> applylayer(l, g, x), l.connection, l.layers, xs)
72-
7370

7471

7572
applychain(::Tuple{}, g::GNNGraph, x) = x

src/layers/conv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -641,7 +641,7 @@ end
641641

642642

643643
function Base.show(io::IO, l::ResGatedGraphConv)
644-
out_channel, in_channel = size(l.weight)
644+
out_channel, in_channel = size(l.A)
645645
print(io, "ResGatedGraphConv(", in_channel, "=>", out_channel)
646646
l.σ == identity || print(io, ", ", l.σ)
647647
print(io, ")")

test/layers/basic.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,21 @@
1515

1616
testmode!(gnn)
1717

18-
test_layer(gnn, g, rtol=1e-5) # exclude BN buffers
18+
test_layer(gnn, g, rtol=1e-5)
1919

2020

2121
@testset "Parallel" begin
2222
AddResidual(l) = Parallel(+, identity, l)
2323

24-
gnn = GNNChain(AddResidual(ResGatedGraphConv(din => d, tanh)),
24+
gnn = GNNChain(ResGatedGraphConv(din => d, tanh),
2525
BatchNorm(d),
2626
AddResidual(ResGatedGraphConv(d => d, tanh)),
2727
BatchNorm(d),
2828
Dense(d, dout))
2929

3030
testmode!(gnn)
3131

32-
test_layer(gnn, g, rtol=1e-5, verbose=true,
33-
exclude_grad_fields=[, :σ², ]) # exclude BN buffers
32+
test_layer(gnn, g, rtol=1e-5)
3433
end
3534
end
3635
end

0 commit comments

Comments
 (0)