|
1 | 1 | @testset "pool" begin
|
2 | 2 | @testset "GlobalPool" begin
|
| 3 | + p = GlobalPool(+) |
3 | 4 | n = 10
|
4 |
| - X = rand(Float32, 16, n) |
| 5 | + chin = 6 |
| 6 | + X = rand(Float32, 6, n) |
5 | 7 | g = GNNGraph(random_regular_graph(n, 4), ndata=X)
|
6 |
| - p = GlobalPool(+) |
7 |
| - y = p(g, X) |
8 |
| - @test y ≈ NNlib.scatter(+, X, ones(Int, n)) |
| 8 | + u = p(g, X) |
| 9 | + @test u ≈ sum(X, dims=2) |
| 10 | + |
| 11 | + ng = 3 |
| 12 | + g = Flux.batch([GNNGraph(random_regular_graph(n, 4), |
| 13 | + ndata=rand(Float32, chin, n)) |
| 14 | + for i=1:ng]) |
| 15 | + u = p(g, g.ndata.x) |
| 16 | + @test size(u) == (chin, ng) |
| 17 | + @test u[:,[1]] ≈ sum(g.ndata.x[:,1:n], dims=2) |
| 18 | + @test p(g).gdata.u == u |
| 19 | + |
9 | 20 | test_layer(p, g, rtol=1e-5, exclude_grad_fields = [:aggr], outtype=:graph)
|
10 | 21 | end
|
11 | 22 |
|
| 23 | + @testset "GlobalAttentionPool" begin |
| 24 | + n = 10 |
| 25 | + chin = 16 |
| 26 | + X = rand(Float32, chin, n) |
| 27 | + g = GNNGraph(random_regular_graph(n, 4), ndata=X) |
| 28 | + fgate = Dense(chin, 1, sigmoid) |
| 29 | + p = GlobalAttentionPool(fgate) |
| 30 | + y = p(g, X) |
| 31 | + test_layer(p, g, rtol=1e-5, outtype=:graph) |
| 32 | + end |
| 33 | + |
| 34 | + |
12 | 35 | @testset "TopKPool" begin
|
13 | 36 | N = 10
|
14 | 37 | k, in_channel = 4, 7
|
|
0 commit comments