Skip to content

Commit 11515eb

Browse files
tests for GNNlib (#466)
1 parent 9338ed7 commit 11515eb

File tree

15 files changed

+293
-254
lines changed

15 files changed

+293
-254
lines changed

.github/workflows/test_GNNlib.yml

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
name: GNNlib
2+
on:
3+
pull_request:
4+
branches:
5+
- master
6+
push:
7+
branches:
8+
- master
9+
jobs:
10+
test:
11+
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }}
12+
runs-on: ${{ matrix.os }}
13+
strategy:
14+
fail-fast: false
15+
matrix:
16+
version:
17+
- '1.10' # Replace this with the minimum Julia version that your package supports.
18+
# - '1' # '1' will automatically expand to the latest stable 1.x release of Julia.
19+
# - 'pre'
20+
os:
21+
- ubuntu-latest
22+
arch:
23+
- x64
24+
25+
steps:
26+
- uses: actions/checkout@v4
27+
- uses: julia-actions/setup-julia@v2
28+
with:
29+
version: ${{ matrix.version }}
30+
arch: ${{ matrix.arch }}
31+
- uses: julia-actions/cache@v2
32+
- uses: julia-actions/julia-buildpkg@v1
33+
- name: Install Julia dependencies and run tests
34+
shell: julia --project=monorepo {0}
35+
run: |
36+
using Pkg
37+
# dev mono repo versions
38+
pkg"registry up"
39+
Pkg.update()
40+
pkg"dev ./GNNGraphs ./GNNlib"
41+
Pkg.test("GNNlib"; coverage=true)
42+
- uses: julia-actions/julia-processcoverage@v1
43+
with:
44+
# directories: ./GNNlib/src, ./GNNlib/ext
45+
directories: ./GNNlib/src
46+
- uses: codecov/codecov-action@v4
47+
with:
48+
files: lcov.info

GNNlib/Project.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2020
GNNlibCUDAExt = "CUDA"
2121

2222
[compat]
23-
ChainRulesCore = "1.24"
2423
CUDA = "4, 5"
24+
ChainRulesCore = "1.24"
2525
DataStructures = "0.18"
2626
GNNGraphs = "1.0"
2727
LinearAlgebra = "1"
@@ -32,7 +32,10 @@ Statistics = "1"
3232
julia = "1.10"
3333

3434
[extras]
35+
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
36+
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
37+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
3538
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3639

3740
[targets]
38-
test = ["Test"]
41+
test = ["Test", "ReTestItems", "Reexport", "SparseArrays"]

GNNlib/src/layers/pool.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11

22

3-
function global_pool(aggr, g::GNNGraph, x::AbstractArray)
4-
return reduce_nodes(aggr, g, x)
3+
function global_pool(l, g::GNNGraph, x::AbstractArray)
4+
return reduce_nodes(l.aggr, g, x)
55
end
66

7-
function global_attention_pool(fgate, ffeat, g::GNNGraph, x::AbstractArray)
8-
α = softmax_nodes(g, fgate(x))
9-
feats = α .* ffeat(x)
7+
function global_attention_pool(l, g::GNNGraph, x::AbstractArray)
8+
α = softmax_nodes(g, l.fgate(x))
9+
feats = α .* l.ffeat(x)
1010
u = reduce_nodes(+, g, feats)
1111
return u
1212
end
@@ -26,11 +26,11 @@ end
2626

2727
topk_index(y::Adjoint, k::Int) = topk_index(y', k)
2828

29-
function set2set_pool(lstm, num_iters, g::GNNGraph, x::AbstractMatrix)
29+
function set2set_pool(l, g::GNNGraph, x::AbstractMatrix)
3030
n_in = size(x, 1)
3131
qstar = zeros_like(x, (2*n_in, g.num_graphs))
32-
for t in 1:num_iters
33-
q = lstm(qstar) # [n_in, n_graphs]
32+
for t in 1:l.num_iters
33+
q = l.lstm(qstar) # [n_in, n_graphs]
3434
qn = broadcast_nodes(g, q) # [n_in, n_nodes]
3535
α = softmax_nodes(g, sum(qn .* x, dims = 1)) # [1, n_nodes]
3636
r = reduce_nodes(+, g, x .* α) # [n_in, n_graphs]

GNNlib/test/msgpass_tests.jl

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
@testitem "msgpass" setup=[SharedTestSetup] begin
2+
#TODO test all graph types
3+
GRAPH_T = :coo
4+
in_channel = 10
5+
out_channel = 5
6+
num_V = 6
7+
num_E = 14
8+
T = Float32
9+
10+
adj = [0 1 0 0 0 0
11+
1 0 0 1 1 1
12+
0 0 0 0 0 1
13+
0 1 0 0 1 0
14+
0 1 0 1 0 1
15+
0 1 1 0 1 0]
16+
17+
X = rand(T, in_channel, num_V)
18+
E = rand(T, in_channel, num_E)
19+
20+
g = GNNGraph(adj, graph_type = GRAPH_T)
21+
22+
@testset "propagate" begin
23+
function message(xi, xj, e)
24+
@test xi === nothing
25+
@test e === nothing
26+
ones(T, out_channel, size(xj, 2))
27+
end
28+
29+
m = propagate(message, g, +, xj = X)
30+
31+
@test size(m) == (out_channel, num_V)
32+
33+
@testset "isolated nodes" begin
34+
x1 = rand(1, 6)
35+
g1 = GNNGraph(collect(1:5), collect(1:5), num_nodes = 6)
36+
y1 = propagate((xi, xj, e) -> xj, g, +, xj = x1)
37+
@test size(y1) == (1, 6)
38+
end
39+
end
40+
41+
@testset "apply_edges" begin
42+
m = apply_edges(g, e = E) do xi, xj, e
43+
@test xi === nothing
44+
@test xj === nothing
45+
ones(out_channel, size(e, 2))
46+
end
47+
48+
@test m == ones(out_channel, num_E)
49+
50+
# With NamedTuple input
51+
m = apply_edges(g, xj = (; a = X, b = 2X), e = E) do xi, xj, e
52+
@test xi === nothing
53+
@test xj.b == 2 * xj.a
54+
@test size(xj.a, 2) == size(xj.b, 2) == size(e, 2)
55+
ones(out_channel, size(e, 2))
56+
end
57+
58+
# NamedTuple output
59+
m = apply_edges(g, e = E) do xi, xj, e
60+
@test xi === nothing
61+
@test xj === nothing
62+
(; a = ones(out_channel, size(e, 2)))
63+
end
64+
65+
@test m.a == ones(out_channel, num_E)
66+
67+
@testset "sizecheck" begin
68+
x = rand(3, g.num_nodes - 1)
69+
@test_throws AssertionError apply_edges(copy_xj, g, xj = x)
70+
@test_throws AssertionError apply_edges(copy_xj, g, xi = x)
71+
72+
x = (a = rand(3, g.num_nodes), b = rand(3, g.num_nodes + 1))
73+
@test_throws AssertionError apply_edges(copy_xj, g, xj = x)
74+
@test_throws AssertionError apply_edges(copy_xj, g, xi = x)
75+
76+
e = rand(3, g.num_edges - 1)
77+
@test_throws AssertionError apply_edges(copy_xj, g, e = e)
78+
end
79+
end
80+
81+
@testset "copy_xj" begin
82+
n = 128
83+
A = sprand(n, n, 0.1)
84+
Adj = map(x -> x > 0 ? 1 : 0, A)
85+
X = rand(10, n)
86+
87+
g = GNNGraph(A, ndata = X, graph_type = GRAPH_T)
88+
89+
function spmm_copyxj_fused(g)
90+
propagate(copy_xj,
91+
g, +; xj = g.ndata.x)
92+
end
93+
94+
function spmm_copyxj_unfused(g)
95+
propagate((xi, xj, e) -> xj,
96+
g, +; xj = g.ndata.x)
97+
end
98+
99+
@test spmm_copyxj_unfused(g) ≈ X * Adj
100+
@test spmm_copyxj_fused(g) ≈ X * Adj
101+
end
102+
103+
@testset "e_mul_xj and w_mul_xj for weighted conv" begin
104+
n = 128
105+
A = sprand(n, n, 0.1)
106+
Adj = map(x -> x > 0 ? 1 : 0, A)
107+
X = rand(10, n)
108+
109+
g = GNNGraph(A, ndata = X, edata = A.nzval, graph_type = GRAPH_T)
110+
111+
function spmm_unfused(g)
112+
propagate((xi, xj, e) -> reshape(e, 1, :) .* xj,
113+
g, +; xj = g.ndata.x, e = g.edata.e)
114+
end
115+
function spmm_fused(g)
116+
propagate(e_mul_xj,
117+
g, +; xj = g.ndata.x, e = g.edata.e)
118+
end
119+
120+
function spmm_fused2(g)
121+
propagate(w_mul_xj,
122+
g, +; xj = g.ndata.x)
123+
end
124+
125+
@test spmm_unfused(g) ≈ X * A
126+
@test spmm_fused(g) ≈ X * A
127+
@test spmm_fused2(g) ≈ X * A
128+
end
129+
130+
@testset "aggregate_neighbors" begin
131+
@testset "sizecheck" begin
132+
m = rand(2, g.num_edges - 1)
133+
@test_throws AssertionError aggregate_neighbors(g, +, m)
134+
135+
m = (a = rand(2, g.num_edges + 1), b = nothing)
136+
@test_throws AssertionError aggregate_neighbors(g, +, m)
137+
end
138+
end
139+
140+
end

GNNlib/test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
using GNNlib
2+
using Test
3+
using ReTestItems
4+
using Random, Statistics
5+
6+
runtests(GNNlib)

GNNlib/test/shared_testsetup.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
@testsetup module SharedTestSetup
2+
3+
import Reexport: @reexport
4+
5+
@reexport using GNNlib
6+
@reexport using GNNGraphs
7+
@reexport using NNlib
8+
@reexport using MLUtils
9+
@reexport using SparseArrays
10+
@reexport using Test, Random, Statistics
11+
12+
end

GNNlib/test/utils_tests.jl

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
@testitem "utils" setup=[SharedTestSetup] begin
2+
# TODO test all graph types
3+
GRAPH_T = :coo
4+
De, Dx = 3, 2
5+
g = MLUtils.batch([rand_graph(10, 60, bidirected=true,
6+
ndata = rand(Dx, 10),
7+
edata = rand(De, 30),
8+
graph_type = GRAPH_T) for i in 1:5])
9+
x = g.ndata.x
10+
e = g.edata.e
11+
12+
@testset "reduce_nodes" begin
13+
r = reduce_nodes(mean, g, x)
14+
@test size(r) == (Dx, g.num_graphs)
15+
@test r[:, 2] ≈ mean(getgraph(g, 2).ndata.x, dims = 2)
16+
17+
r2 = reduce_nodes(mean, graph_indicator(g), x)
18+
@test r2 == r
19+
end
20+
21+
@testset "reduce_edges" begin
22+
r = reduce_edges(mean, g, e)
23+
@test size(r) == (De, g.num_graphs)
24+
@test r[:, 2] ≈ mean(getgraph(g, 2).edata.e, dims = 2)
25+
end
26+
27+
@testset "softmax_nodes" begin
28+
r = softmax_nodes(g, x)
29+
@test size(r) == size(x)
30+
@test r[:, 1:10] ≈ softmax(getgraph(g, 1).ndata.x, dims = 2)
31+
end
32+
33+
@testset "softmax_edges" begin
34+
r = softmax_edges(g, e)
35+
@test size(r) == size(e)
36+
@test r[:, 1:60] ≈ softmax(getgraph(g, 1).edata.e, dims = 2)
37+
end
38+
39+
@testset "broadcast_nodes" begin
40+
z = rand(4, g.num_graphs)
41+
r = broadcast_nodes(g, z)
42+
@test size(r) == (4, g.num_nodes)
43+
@test r[:, 1] ≈ z[:, 1]
44+
@test r[:, 10] ≈ z[:, 1]
45+
@test r[:, 11] ≈ z[:, 2]
46+
end
47+
48+
@testset "broadcast_edges" begin
49+
z = rand(4, g.num_graphs)
50+
r = broadcast_edges(g, z)
51+
@test size(r) == (4, g.num_edges)
52+
@test r[:, 1] ≈ z[:, 1]
53+
@test r[:, 60] ≈ z[:, 1]
54+
@test r[:, 61] ≈ z[:, 2]
55+
end
56+
57+
@testset "softmax_edge_neighbors" begin
58+
s = [1, 2, 3, 4]
59+
t = [5, 5, 6, 6]
60+
g2 = GNNGraph(s, t)
61+
e2 = randn(Float32, 3, g2.num_edges)
62+
z = softmax_edge_neighbors(g2, e2)
63+
@test size(z) == size(e2)
64+
@test z[:, 1:2] ≈ NNlib.softmax(e2[:, 1:2], dims = 2)
65+
@test z[:, 3:4] ≈ NNlib.softmax(e2[:, 3:4], dims = 2)
66+
end
67+
end
68+

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ version = "0.6.20"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
8-
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
98
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
109
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1110
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
@@ -27,7 +26,6 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2726
[compat]
2827
CUDA = "4, 5"
2928
ChainRulesCore = "1"
30-
DataStructures = "0.18"
3129
Flux = "0.14"
3230
Functors = "0.4.1"
3331
GNNGraphs = "1.0"

src/GraphNeuralNetworks.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ using NNlib
99
using NNlib: scatter, gather
1010
using ChainRulesCore
1111
using Reexport
12-
using DataStructures: nlargest
1312
using MLUtils: zeros_like
1413

1514
using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T,

src/deprecations.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
# V1.0 deprecations
33
# TODO for some reason this is not working
44
# @deprecate (l::GCNConv)(g, x, edge_weight, norm_fn; conv_weight=nothing) l(g, x, edge_weight; norm_fn, conv_weight)
5+
# @deprecate (l::GNNLayer)(gs::AbstractVector{<:GNNGraph}, args...; kws...) l(MLUtils.batch(gs), args...; kws...)

0 commit comments

Comments
 (0)