Skip to content

Commit b720b05

Browse files
fix GlobalPooling with graph only input
1 parent afad470 commit b720b05

File tree

4 files changed

+16
-5
lines changed

4 files changed

+16
-5
lines changed

src/layers/pool.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ function (l::GlobalPool)(g::GNNGraph, x::AbstractArray)
4040
return reduce_nodes(l.aggr, g, x)
4141
end
4242

43+
(l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata=l(g, node_features(g)))
44+
4345
"""
4446
TopKPool(adj, k, in_channel)
4547

src/utils.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_nee
5252
end
5353

5454
sz = map(x -> x isa AbstractArray ? size(x)[end] : 0, data)
55-
5655
if duplicate_if_needed
5756
# Used to copy edge features on reverse edges
5857
@assert all(s -> s == 0 || s == n || s == n÷2, sz)

test/layers/pool.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
@testset "GlobalPool" begin
33
n = 10
44
X = rand(16, n)
5-
g = GNNGraph(random_regular_graph(n, 4))
5+
g = GNNGraph(random_regular_graph(n, 4), ndata=X)
66
p = GlobalPool(+)
7-
@test p(g, X) NNlib.scatter(+, X, ones(Int, n))
7+
y = p(g, X)
8+
@test y NNlib.scatter(+, X, ones(Int, n))
9+
test_layer(p, g, rtol=1e-5, exclude_grad_fields = [:aggr], outtype=:graph)
810
end
911

1012
@testset "TopKPool" begin

test/test_utils.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ function test_layer(l, g::GNNGraph; atol = 1e-7, rtol = 1e-5,
1818
verbose = false,
1919
test_gpu = TEST_GPU,
2020
outsize = nothing,
21+
outtype = :node,
2122
)
2223

2324
# TODO these give errors, probably some bugs in ChainRulesTestUtils
@@ -57,8 +58,15 @@ function test_layer(l, g::GNNGraph; atol = 1e-7, rtol = 1e-5,
5758
@test ycoo y
5859

5960
g′ = f(l, g)
60-
@test g′.ndata.x y
61-
61+
if outtype == :node
62+
@test g′.ndata.x y
63+
elseif outtype == :edge
64+
@test g′.edata.e y
65+
elseif outtype == :graph
66+
@test g′.gdata.u y
67+
else
68+
@error "wrong outtype $outtype"
69+
end
6270
if test_gpu
6371
ygpu = f(lgpu, ggpu, xgpu)
6472
@test ygpu isa CuArray

0 commit comments

Comments
 (0)