Skip to content

Commit b3d528b

Browse files
conver more tests
1 parent 15ea38b commit b3d528b

File tree

11 files changed

+524
-449
lines changed

11 files changed

+524
-449
lines changed

GNNGraphs/Project.toml

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ authors = ["Carlo Lucibello and contributors"]
44
version = "1.4.1"
55

66
[deps]
7-
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
87
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
98
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
109
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
@@ -28,15 +27,13 @@ GNNGraphsCUDAExt = "CUDA"
2827
GNNGraphsSimpleWeightedGraphsExt = "SimpleWeightedGraphs"
2928

3029
[compat]
31-
Adapt = "4"
3230
CUDA = "5"
3331
ChainRulesCore = "1"
3432
Functors = "0.5"
3533
Graphs = "1.4"
3634
KrylovKit = "0.8"
3735
LinearAlgebra = "1"
3836
MLDataDevices = "1.0"
39-
MLDatasets = "0.7"
4037
MLUtils = "0.4"
4138
NNlib = "0.9"
4239
NearestNeighbors = "0.4"
@@ -45,21 +42,4 @@ SimpleWeightedGraphs = "1.4.0"
4542
SparseArrays = "1"
4643
Statistics = "1"
4744
StatsBase = "0.34"
48-
cuDNN = "1"
4945
julia = "1.10"
50-
51-
[extras]
52-
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
53-
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
54-
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
55-
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
56-
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
57-
InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
58-
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
59-
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
60-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
61-
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
62-
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
63-
64-
[targets]
65-
test = ["Test", "Adapt", "DataFrames", "InlineStrings", "SimpleWeightedGraphs", "Zygote", "FiniteDifferences", "ChainRulesTestUtils", "MLDatasets", "CUDA", "cuDNN"]

GNNGraphs/test/Project.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,21 @@
11
[deps]
22
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
3+
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
34
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
5+
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
46
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
57
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8+
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
69
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
10+
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
711
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
812
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
13+
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
14+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
915
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1016
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1117
TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe"
18+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
19+
20+
[compat]
21+
GPUArraysCore = "0.1"

GNNGraphs/test/chainrules.jl

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,23 @@
1-
@testset "dict constructor" begin
1+
@testitem "dict constructor" setup=[GraphsTestModule] begin
2+
using .GraphsTestModule
23
grad = gradient(1.) do x
34
d = Dict([:x => x, :y => 5]...)
45
return sum(d[:x].^2)
56
end[1]
67

78
@test grad == 2
89

9-
## BROKEN Constructors
10-
# grad = gradient(1.) do x
11-
# d = Dict([(:x => x), (:y => 5)])
12-
# return sum(d[:x].^2)
13-
# end[1]
14-
15-
# @test grad == 2
10+
grad = gradient(1.) do x
11+
d = Dict([:x => x, :y => 5])
12+
return sum(d[:x].^2)
13+
end[1]
1614

15+
@test grad == 2
1716

18-
# grad = gradient(1.) do x
19-
# d = Dict([(:x => x), (:y => 5)])
20-
# return sum(d[:x].^2)
21-
# end[1]
17+
grad = gradient(1.) do x
18+
d = Dict(:x => x, :y => 5)
19+
return sum(d[:x].^2)
20+
end[1]
2221

23-
# @test grad == 2
22+
@test grad == 2
2423
end

GNNGraphs/test/convert.jl

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,24 @@
1-
if TEST_GPU
2-
@testset "to_coo(dense) on gpu" begin
3-
get_st(A) = GNNGraphs.to_coo(A)[1][1:2]
4-
get_val(A) = GNNGraphs.to_coo(A)[1][3]
1+
@testitem "to_coo(dense) on gpu" setup=[GraphsTestModule] tags=[:gpu] begin
2+
using .GraphsTestModule
3+
get_st(A) = GNNGraphs.to_coo(A)[1][1:2]
4+
get_val(A) = GNNGraphs.to_coo(A)[1][3]
5+
gpu = gpu_device(force=true)
6+
A = gpu([0 2 2; 2.0 0 2; 2 2 0])
57

6-
A = cu([0 2 2; 2.0 0 2; 2 2 0])
8+
y = get_val(A)
9+
@test y isa AbstractVector{Float32}
10+
@test get_device(y) == get_device(A)
11+
@test Array(y) [2, 2, 2, 2, 2, 2]
712

8-
y = get_val(A)
9-
@test y isa CuVector{Float32}
10-
@test Array(y) [2, 2, 2, 2, 2, 2]
13+
s, t = get_st(A)
14+
@test s isa AbstractVector{<:Integer}
15+
@test t isa AbstractVector{<:Integer}
16+
@test get_device(s) == get_device(A)
17+
@test get_device(t) == get_device(A)
18+
@test Array(s) == [2, 3, 1, 3, 1, 2]
19+
@test Array(t) == [1, 1, 2, 2, 3, 3]
1120

12-
s, t = get_st(A)
13-
@test s isa CuVector{<:Integer}
14-
@test t isa CuVector{<:Integer}
15-
@test Array(s) == [2, 3, 1, 3, 1, 2]
16-
@test Array(t) == [1, 1, 2, 2, 3, 3]
17-
18-
@test gradient(A -> sum(get_val(A)), A)[1] isa CuMatrix{Float32}
19-
end
21+
grad = gradient(A -> sum(get_val(A)), A)[1]
22+
@test grad isa AbstractMatrix{Float32}
23+
@test get_device(grad) == get_device(A)
2024
end

GNNGraphs/test/datastore.jl

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11

2-
@testset "constructor" begin
2+
@testitem "constructor" begin
33
@test_throws AssertionError DataStore(10, (:x => rand(10), :y => rand(2, 4)))
44

55
@testset "keyword args" begin
@@ -13,7 +13,7 @@
1313
end
1414
end
1515

16-
@testset "getproperty / setproperty!" begin
16+
@testitem "getproperty / setproperty!" begin
1717
x = rand(10)
1818
ds = DataStore(10, (:x => x, :y => rand(2, 10)))
1919
@test ds.x == ds[:x] == x
@@ -25,14 +25,14 @@ end
2525
@test fill(DataStore(), 3) isa Vector
2626
end
2727

28-
@testset "setindex!" begin
28+
@testitem "setindex!" begin
2929
ds = DataStore(10)
3030
x = rand(10)
3131
@test (ds[:x] = x) == x # Tests setindex!
3232
@test ds.x == ds[:x] == x
3333
end
3434

35-
@testset "map" begin
35+
@testitem "map" begin
3636
ds = DataStore(10, (:x => rand(10), :y => rand(2, 10)))
3737
ds2 = map(x -> x .+ 1, ds)
3838
@test ds2.x == ds.x .+ 1
@@ -41,33 +41,34 @@ end
4141
@test_throws AssertionError ds2=map(x -> [x; x], ds)
4242
end
4343

44-
@testset "getdata / getn" begin
44+
@testitem "getdata / getn" begin
4545
ds = DataStore(10, (:x => rand(10), :y => rand(2, 10)))
46-
@test getdata(ds) == getfield(ds, :_data)
46+
@test GNNGraphs.getdata(ds) == getfield(ds, :_data)
4747
@test_throws KeyError ds.data
48-
@test getn(ds) == getfield(ds, :_n)
48+
@test GNNGraphs.getn(ds) == getfield(ds, :_n)
4949
@test_throws KeyError ds.n
5050
end
5151

52-
@testset "cat empty" begin
52+
@testitem "cat empty" begin
5353
ds1 = DataStore(2, (:x => rand(2)))
5454
ds2 = DataStore(1, (:x => rand(1)))
5555
dsempty = DataStore(0, (:x => rand(0)))
5656

5757
ds = GNNGraphs.cat_features(ds1, ds2)
58-
@test getn(ds) == 3
58+
@test GNNGraphs.getn(ds) == 3
5959
ds = GNNGraphs.cat_features(ds1, dsempty)
60-
@test getn(ds) == 2
60+
@test GNNGraphs.getn(ds) == 2
6161

6262
# issue #280
6363
g = GNNGraph([1], [2])
6464
h = add_edges(g, Int[], Int[]) # adds no edges
65-
@test getn(g.edata) == 1
66-
@test getn(h.edata) == 1
65+
@test GNNGraphs.getn(g.edata) == 1
66+
@test GNNGraphs.getn(h.edata) == 1
6767
end
6868

6969

70-
@testset "gradient" begin
70+
@testitem "gradient" setup=[GraphsTestModule] begin
71+
using .GraphsTestModule
7172
ds = DataStore(10, (:x => rand(10), :y => rand(2, 10)))
7273

7374
f1(ds) = sum(ds.x)
@@ -80,11 +81,12 @@ end
8081
@test grad == exp.(x)
8182
end
8283

83-
@testset "functor" begin
84+
@testitem "functor" begin
85+
using Functors
8486
ds = DataStore(10, (:x => zeros(10), :y => ones(2, 10)))
8587
p, re = Functors.functor(ds)
86-
@test p[1] === getn(ds)
87-
@test p[2] === getdata(ds)
88+
@test p[1] === GNNGraphs.getn(ds)
89+
@test p[2] === GNNGraphs.getdata(ds)
8890
@test ds == re(p)
8991

9092
ds2 = Functors.fmap(ds) do x

GNNGraphs/test/ext/SimpleWeightedGraphs.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
@testset "simple_weighted_graph" begin
1+
@testitem "simple_weighted_graph" begin
2+
using SimpleWeightedGraphs
23
srcs = [1, 2, 1]
34
dsts = [2, 3, 3]
45
wts = [0.5, 0.8, 2.0]

0 commit comments

Comments
 (0)