
Commit a590094

temporarily reintegrate GNNGraphs tests (#449)
* add back GNNGraph tests
* reintegrate extensions
* don't use DeviceUtils
* import Flux

1 parent b1e5669 · commit a590094

File tree

20 files changed: +2016 −4 lines changed
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+
+GNNGraphs._rand_dense_vector(A::CUMAT_T) = CUDA.randn(size(A, 1))
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+
+GNNGraphs.dense_zeros_like(a::CUMAT_T, T::Type, sz = size(a)) = CUDA.zeros(T, sz)
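These one-line methods extend GNNGraphs' dense-array helpers to CUDA matrices through multiple dispatch. For orientation, a minimal sketch of the generic CPU counterparts they shadow; only the names and signatures come from the diff, the bodies are assumptions:

```julia
# Hypothetical CPU fallbacks specialized by the CUDA methods above.
# Bodies are assumed for illustration, not the package's definitions.
_rand_dense_vector(A::AbstractMatrix{T}) where {T} = randn(T, size(A, 1))
dense_zeros_like(a::AbstractMatrix, T::Type, sz = size(a)) = zeros(T, sz)
```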
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+
+GNNGraphs.iscuarray(x::AnyCuArray) = true
+
+
+function sort_edge_index(u::AnyCuArray, v::AnyCuArray)
+    # TODO: proper CUDA-friendly implementation
+    sort_edge_index(u |> Flux.cpu, v |> Flux.cpu) |> Flux.gpu
+end
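The CUDA method above simply round-trips through host memory. A hedged sketch of the generic `sort_edge_index` it presumably delegates to, which orders the edge list lexicographically by (source, target); the body here is an assumption:

```julia
# Assumed generic method: sort edges lexicographically by (source, target)
# and apply the same permutation to both index vectors.
function sort_edge_index(u::AbstractVector, v::AbstractVector)
    p = sortperm(collect(zip(u, v)))  # tuples compare lexicographically
    return u[p], v[p]
end
```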

ext/GraphNeuralNetworksCUDAExt/GraphNeuralNetworksCUDAExt.jl

Lines changed: 3 additions & 0 deletions
@@ -9,6 +9,9 @@ import GraphNeuralNetworks: propagate
 
 const CUMAT_T = Union{CUDA.AnyCuMatrix, CUDA.CUSPARSE.CuSparseMatrix}
 
+include("GNNGraphs/query.jl")
+include("GNNGraphs/transform.jl")
+include("GNNGraphs/utils.jl")
 include("msgpass.jl")
 
 end #module
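The three new `include`s pull the GNNGraphs CUDA methods into this package extension. On Julia ≥ 1.9 an extension compiles and loads only once its weak dependency (here CUDA.jl) is loaded; a minimal sketch of checking that at the REPL with the standard `Base.get_extension` API:

```julia
using GraphNeuralNetworks
using CUDA  # loading the weak dependency triggers the extension

# Returns the extension module, or nothing if it is not loaded (Julia >= 1.9).
ext = Base.get_extension(GraphNeuralNetworks, :GraphNeuralNetworksCUDAExt)
@assert ext !== nothing
```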
Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+module GraphNeuralNetworksSimpleWeightedGraphsExt
+
+using GraphNeuralNetworks
+using Graphs
+using SimpleWeightedGraphs
+
+function GraphNeuralNetworks.GNNGraph(g::T; kws...) where
+    {T <: Union{SimpleWeightedGraph, SimpleWeightedDiGraph}}
+    return GNNGraph(g.weights; kws...)
+end
+
+end #module
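The method forwards the weight matrix of a SimpleWeightedGraphs graph to the adjacency-matrix `GNNGraph` constructor. A minimal usage sketch, assuming both packages are loaded so the extension is active:

```julia
using Graphs, SimpleWeightedGraphs, GraphNeuralNetworks

swg = SimpleWeightedGraph(3)   # 3 nodes, no edges yet
add_edge!(swg, 1, 2, 0.5)      # weighted edges
add_edge!(swg, 2, 3, 2.0)

g = GNNGraph(swg)              # the weights become the graph's edge weights
@assert g.num_nodes == 3
```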

src/GNNGraphs/GNNGraphs.jl

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,8 @@ using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree
                has_self_loops, is_directed
 import NearestNeighbors
 import NNlib
+import Flux
+using Flux: batch
 import StatsBase
 import KrylovKit
 using ChainRulesCore

src/GNNGraphs/transform.jl

Lines changed: 8 additions & 4 deletions
@@ -1126,10 +1126,13 @@ function negative_sample(g::GNNGraph;
 
     s, t = edge_index(g)
     n = g.num_nodes
-    device = get_device(s)
-    cdevice = cpu_device()
-    # Convert to gpu since set operations and sampling are not supported by CUDA.jl
-    s, t = cdevice(s), cdevice(t)
+    if iscuarray(s)
+        # Convert to cpu since set operations and sampling are not supported by CUDA.jl
+        device = Flux.gpu
+        s, t = Flux.cpu(s), Flux.cpu(t)
+    else
+        device = Flux.cpu
+    end
     idx_pos, maxid = edge_encoding(s, t, n)
     if bidirected
         num_neg_edges = num_neg_edges ÷ 2

@@ -1156,6 +1159,7 @@ function negative_sample(g::GNNGraph;
     return GNNGraph(s_neg, t_neg, num_nodes = n) |> device
 end
 
+
 """
     rand_edge_split(g::GNNGraph, frac; bidirected=is_bidirected(g)) -> g1, g2
 
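The rewritten branch swaps the device-utility calls (`get_device` / `cpu_device`) for a plain `iscuarray` check with Flux's `cpu` / `gpu` moves: sampling always runs on the CPU, and the sampled graph is moved back to where the input lived. A usage sketch, with keyword names taken from the surrounding code:

```julia
using GraphNeuralNetworks

g = rand_graph(10, 30, bidirected = false)
g_neg = negative_sample(g, num_neg_edges = 10)  # edges absent from g
@assert g_neg.num_edges == 10
```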
test/GNNGraphs/chainrules.jl

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+@testset "dict constructor" begin
+    grad = gradient(1.0) do x
+        d = Dict([:x => x, :y => 5]...)
+        return sum(d[:x] .^ 2)
+    end[1]
+
+    @test grad == 2
+
+    ## BROKEN Constructors
+    # grad = gradient(1.0) do x
+    #     d = Dict([(:x => x), (:y => 5)])
+    #     return sum(d[:x] .^ 2)
+    # end[1]
+    #
+    # @test grad == 2
+end
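The working form splats a vector of pairs into `Dict`; the commented-out constructor form is not yet differentiable. The expected gradient follows from d/dx x² = 2x, which is 2 at x = 1. A direct check of that arithmetic without the `Dict` indirection, assuming `gradient` here is Zygote's as elsewhere in this test suite:

```julia
using Zygote: gradient

@assert gradient(x -> x^2, 1.0)[1] == 2.0  # d/dx x^2 = 2x, evaluated at x = 1
```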

test/GNNGraphs/convert.jl

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+if TEST_GPU
+    @testset "to_coo(dense) on gpu" begin
+        get_st(A) = GNNGraphs.to_coo(A)[1][1:2]
+        get_val(A) = GNNGraphs.to_coo(A)[1][3]
+
+        A = cu([0 2 2; 2.0 0 2; 2 2 0])
+
+        y = get_val(A)
+        @test y isa CuVector{Float32}
+        @test Array(y) ≈ [2, 2, 2, 2, 2, 2]
+
+        s, t = get_st(A)
+        @test s isa CuVector{<:Integer}
+        @test t isa CuVector{<:Integer}
+        @test Array(s) == [2, 3, 1, 3, 1, 2]
+        @test Array(t) == [1, 1, 2, 2, 3, 3]
+
+        @test gradient(A -> sum(get_val(A)), A)[1] isa CuMatrix{Float32}
+    end
+end
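The expected `s` and `t` vectors follow from Julia's column-major traversal of the dense adjacency matrix: `to_coo` reports every nonzero entry A[i, j] as an edge from i to j. A CPU-only check that reproduces the expectations above (the comprehensions are illustrative, not GNNGraphs' implementation):

```julia
A = [0 2 2; 2.0 0 2; 2 2 0]
# Nonzeros visited column by column, matching the test expectations above.
s = [i for j in 1:3 for i in 1:3 if A[i, j] != 0]  # sources
t = [j for j in 1:3 for i in 1:3 if A[i, j] != 0]  # targets
@assert s == [2, 3, 1, 3, 1, 2] && t == [1, 1, 2, 2, 3, 3]
```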

test/GNNGraphs/datastore.jl

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
+
+@testset "constructor" begin
+    @test_throws AssertionError DataStore(10, (:x => rand(10), :y => rand(2, 4)))
+
+    @testset "keyword args" begin
+        ds = DataStore(10, x = rand(10), y = rand(2, 10))
+        @test size(ds.x) == (10,)
+        @test size(ds.y) == (2, 10)
+
+        ds = DataStore(x = rand(10), y = rand(2, 10))
+        @test size(ds.x) == (10,)
+        @test size(ds.y) == (2, 10)
+    end
+end
+
+@testset "getproperty / setproperty!" begin
+    x = rand(10)
+    ds = DataStore(10, (:x => x, :y => rand(2, 10)))
+    @test ds.x == ds[:x] == x
+    @test_throws DimensionMismatch ds.z = rand(12)
+    ds.z = [1:10;]
+    @test ds.z == [1:10;]
+    vec = [DataStore(10, (:x => x,)), DataStore(10, (:x => x, :y => rand(2, 10)))]
+    @test vec.x == [x, x]
+    @test_throws KeyError vec.z
+    @test vec._n == [10, 10]
+    @test vec._data == [Dict(:x => x), Dict(:x => x, :y => vec[2].y)]
+end
+
+@testset "setindex!" begin
+    ds = DataStore(10)
+    x = rand(10)
+    @test (ds[:x] = x) == x # tests setindex!
+    @test ds.x == ds[:x] == x
+end
+
+@testset "map" begin
+    ds = DataStore(10, (:x => rand(10), :y => rand(2, 10)))
+    ds2 = map(x -> x .+ 1, ds)
+    @test ds2.x == ds.x .+ 1
+    @test ds2.y == ds.y .+ 1
+
+    @test_throws AssertionError ds2 = map(x -> [x; x], ds)
+end
+
+@testset "getdata / getn" begin
+    ds = DataStore(10, (:x => rand(10), :y => rand(2, 10)))
+    @test getdata(ds) == getfield(ds, :_data)
+    @test_throws KeyError ds.data
+    @test getn(ds) == getfield(ds, :_n)
+    @test_throws KeyError ds.n
+end
+
+@testset "cat empty" begin
+    ds1 = DataStore(2, (:x => rand(2)))
+    ds2 = DataStore(1, (:x => rand(1)))
+    dsempty = DataStore(0, (:x => rand(0)))
+
+    ds = GNNGraphs.cat_features(ds1, ds2)
+    @test getn(ds) == 3
+    ds = GNNGraphs.cat_features(ds1, dsempty)
+    @test getn(ds) == 2
+
+    # issue #280
+    g = GNNGraph([1], [2])
+    h = add_edges(g, Int[], Int[]) # adds no edges
+    @test getn(g.edata) == 1
+    @test getn(h.edata) == 1
+end
+
+@testset "gradient" begin
+    ds = DataStore(10, (:x => rand(10), :y => rand(2, 10)))
+
+    f1(ds) = sum(ds.x)
+    grad = gradient(f1, ds)[1]
+    @test grad._data[:x] ≈ ngradient(f1, ds)[1][:x]
+
+    g = rand_graph(5, 2)
+    x = rand(2, 5)
+    grad = gradient(x -> sum(exp, GNNGraph(g, ndata = x).ndata.x), x)[1]
+    @test grad == exp.(x)
+end
+
+@testset "functor" begin
+    ds = DataStore(10, (:x => zeros(10), :y => ones(2, 10)))
+    p, re = Functors.functor(ds)
+    @test p[1] === getn(ds)
+    @test p[2] === getdata(ds)
+    @test ds == re(p)
+
+    ds2 = Functors.fmap(ds) do x
+        if x isa AbstractArray
+            x .+ 1
+        else
+            x
+        end
+    end
+    @test ds2 isa DataStore
+    @test ds2.x == ds.x .+ 1
+end
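The `gradient` testset checks Zygote's structural gradient against `ngradient`, a numerical-gradient helper presumably defined elsewhere in the test suite. A simplified, array-only sketch of what such a helper typically does (central finite differences; the name `ngradient_sketch` and the omission of `DataStore` handling are assumptions):

```julia
# Hypothetical central-finite-difference gradient of a scalar function
# of an array; the real test helper must also accept DataStore inputs.
function ngradient_sketch(f, x::AbstractArray; eps = 1e-6)
    g = zero(x)
    for i in eachindex(x)
        orig = x[i]
        x[i] = orig + eps; fp = f(x)
        x[i] = orig - eps; fm = f(x)
        x[i] = orig
        g[i] = (fp - fm) / (2eps)
    end
    return g
end
```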

0 commit comments