Commit c514d27

Merge pull request #45 from CarloLucibello/cl/dev
add test for Cora example

2 parents: e22afdc + 3585cca

8 files changed: +127 −29 lines

.github/workflows/CI.yml

Lines changed: 2 additions & 0 deletions

@@ -22,6 +22,8 @@ jobs:
           - ubuntu-latest
         arch:
           - x64
+    env: # Don't use system Python (needed by PyCall)
+      PYTHON: ""
     steps:
       - uses: actions/checkout@v2
      - uses: julia-actions/setup-julia@v1
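
Setting PYTHON to the empty string tells PyCall to install and use its own
private Conda-based Python rather than whatever interpreter happens to be on
the CI machine, which keeps the tests (which pull in PyCall, presumably
through MLDatasets) reproducible. A minimal sketch of applying the same
setting locally, following standard PyCall behavior:

    # Point PyCall at a private Conda-based Python instead of the system one.
    ENV["PYTHON"] = ""
    using Pkg
    Pkg.build("PyCall")   # rebuild so the new setting takes effect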

Project.toml

Lines changed: 2 additions & 1 deletion

@@ -36,8 +36,9 @@ julia = "1.6"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
 FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
+MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [targets]
-test = ["Test", "Adapt", "Zygote", "FiniteDifferences", "ChainRulesTestUtils"]
+test = ["Test", "Adapt", "Zygote", "FiniteDifferences", "ChainRulesTestUtils", "MLDatasets"]

examples/graph_classification_tudataset.jl

Lines changed: 4 additions & 5 deletions

@@ -1,13 +1,12 @@
 # An example of graph classification

 using Flux
-using Flux: @functor, dropout, onecold, onehotbatch, getindex
+using Flux: onecold, onehotbatch
 using Flux.Losses: logitbinarycrossentropy
 using Flux.Data: DataLoader
 using GraphNeuralNetworks
 using MLDatasets: TUDataset
 using Statistics, Random
-using LearnBase: getobs
 using CUDA
 CUDA.allowscalar(false)

@@ -76,8 +75,8 @@ function train(; kws...)
     @info gfull

     perm = randperm(gfull.num_graphs)
-    gtrain = getobs(gfull, perm[1:NUM_TRAIN])
-    gtest = getobs(gfull, perm[NUM_TRAIN+1:end])
+    gtrain, _ = getgraph(gfull, perm[1:NUM_TRAIN])
+    gtest, _ = getgraph(gfull, perm[NUM_TRAIN+1:end])
     train_loader = DataLoader(gtrain, batchsize=args.batchsize, shuffle=true)
     test_loader = DataLoader(gtest, batchsize=args.batchsize, shuffle=false)

@@ -121,4 +120,4 @@ function train(; kws...)
     end
 end

-# train()
+train()
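
The example now splits the batched dataset with getgraph instead of
LearnBase.getobs. In this version getgraph(g, idxs) extracts the subgraph
made of the selected graphs and also returns a node map, hence the discarded
second value. A minimal sketch of the pattern (rand_graph and Flux.batch
assumed available as in the package docs):

    using Flux, GraphNeuralNetworks, Random

    gs = [rand_graph(10, 30) for _ in 1:4]   # four small random graphs
    gall = Flux.batch(gs)                    # one GNNGraph, num_graphs == 4

    perm = randperm(gall.num_graphs)
    gtrain, _ = getgraph(gall, perm[1:3])    # subgraph + node map
    gtest, _  = getgraph(gall, perm[4:end])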

examples/node_classification_cora.jl

Lines changed: 4 additions & 5 deletions

@@ -1,7 +1,7 @@
 # An example of semi-supervised node classification

 using Flux
-using Flux: @functor, dropout, onecold, onehotbatch
+using Flux: onecold, onehotbatch
 using Flux.Losses: logitcrossentropy
 using GraphNeuralNetworks
 using MLDatasets: Cora

@@ -28,13 +28,12 @@ end

 function train(; kws...)
     args = Args(; kws...)
-    if args.seed > 0
-        Random.seed!(args.seed)
-        CUDA.seed!(args.seed)
-    end
+
+    args.seed > 0 && Random.seed!(args.seed)

     if args.usecuda && CUDA.functional()
         device = gpu
+        args.seed > 0 && CUDA.seed!(args.seed)
         @info "Training on GPU"
     else
         device = cpu
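
Moving CUDA.seed! inside the CUDA.functional() branch means the GPU RNG is
only touched when a working device exists; the old unconditional call could
error on CPU-only machines. A minimal sketch of the pattern, using a
hypothetical set_seeds helper:

    using Random, CUDA

    function set_seeds(seed; usecuda=true)
        seed > 0 && Random.seed!(seed)   # CPU RNG: always safe to seed
        if usecuda && CUDA.functional()
            CUDA.seed!(seed)             # GPU RNG: only when a device exists
        end
    end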

src/layers/conv.jl

Lines changed: 5 additions & 7 deletions

@@ -288,9 +288,8 @@ end


 function Base.show(io::IO, l::GATConv)
-    in_channel = size(l.weight, ndims(l.weight))
-    out_channel = size(l.weight, ndims(l.weight)-1)
-    print(io, "GATConv(", in_channel, "=>", out_channel)
+    out_channel, in_channel = size(l.weight)
+    print(io, "GATConv(", in_channel, "=>", out_channel ÷ l.heads)
     print(io, ", LeakyReLU(λ=", l.negative_slope)
     print(io, "))")
 end

@@ -341,7 +340,7 @@ update_node(l::GatedGraphConv, m, x) = m
 # remove after https://github.com/JuliaDiff/ChainRules.jl/pull/521
 @non_differentiable fill!(x...)

-function (l::GatedGraphConv)(g::GNNGraph, H::AbstractMatrix{S}) where {T<:AbstractVector,S<:Real}
+function (l::GatedGraphConv)(g::GNNGraph, H::AbstractMatrix{S}) where {S<:Real}
     check_num_nodes(g, H)
     m, n = size(H)
     @assert (m <= l.out_ch) "number of input features must be less than or equal to the number of output features"

@@ -567,9 +566,8 @@ function (l::SAGEConv)(g::GNNGraph, x::AbstractMatrix)
 end

 function Base.show(io::IO, l::SAGEConv)
-    in_channel = size(l.weight1, ndims(l.weight1))
-    out_channel = size(l.weight1, ndims(l.weight1)-1)
-    print(io, "SAGEConv(", in_channel, " => ", out_channel)
+    out_channel, in_channel = size(l.weight)
+    print(io, "SAGEConv(", in_channel ÷ 2, " => ", out_channel)
     l.σ == identity || print(io, ", ", l.σ)
     print(io, ", aggr=", l.aggr)
     print(io, ")")
test/examples/node_classification_cora.jl (new file)

Lines changed: 98 additions & 0 deletions

@@ -0,0 +1,98 @@
+using Flux
+using Flux: onecold, onehotbatch
+using Flux.Losses: logitcrossentropy
+using GraphNeuralNetworks
+using MLDatasets: Cora
+using Statistics, Random
+using CUDA
+CUDA.allowscalar(false)
+
+function eval_loss_accuracy(X, y, ids, model, g)
+    ŷ = model(g, X)
+    l = logitcrossentropy(ŷ[:,ids], y[:,ids])
+    acc = mean(onecold(ŷ[:,ids] |> cpu) .== onecold(y[:,ids] |> cpu))
+    return (loss = round(l, digits=4), acc = round(acc*100, digits=2))
+end
+
+# arguments for the `train` function
+Base.@kwdef mutable struct Args
+    η = 1f-3        # learning rate
+    epochs = 20     # number of epochs
+    seed = 17       # set seed > 0 for reproducibility
+    usecuda = false # if true use cuda (if available)
+    nhidden = 128   # dimension of hidden features
+end
+
+function train(Layer; verbose=false, kws...)
+    args = Args(; kws...)
+    args.seed > 0 && Random.seed!(args.seed)
+
+    if args.usecuda && CUDA.functional()
+        device = Flux.gpu
+        args.seed > 0 && CUDA.seed!(args.seed)
+    else
+        device = Flux.cpu
+    end
+
+    # LOAD DATA
+    data = Cora.dataset()
+    g = GNNGraph(data.adjacency_list) |> device
+    X = data.node_features |> device
+    y = onehotbatch(data.node_labels, 1:data.num_classes) |> device
+    train_ids = data.train_indices |> device
+    val_ids = data.val_indices |> device
+    test_ids = data.test_indices |> device
+    ytrain = y[:,train_ids]
+
+    nin, nhidden, nout = size(X,1), args.nhidden, data.num_classes
+
+    ## DEFINE MODEL
+    model = GNNChain(Layer(nin, nhidden),
+                     Dropout(0.5),
+                     Layer(nhidden, nhidden),
+                     Dense(nhidden, nout)) |> device
+
+    ps = Flux.params(model)
+    opt = ADAM(args.η)
+
+    ## TRAINING
+    function report(epoch)
+        train = eval_loss_accuracy(X, y, train_ids, model, g)
+        test = eval_loss_accuracy(X, y, test_ids, model, g)
+        println("Epoch: $epoch  Train: $(train)  Test: $(test)")
+    end
+
+    verbose && report(0)
+    for epoch in 1:args.epochs
+        gs = Flux.gradient(ps) do
+            ŷ = model(g, X)
+            logitcrossentropy(ŷ[:,train_ids], ytrain)
+        end
+        verbose && report(epoch)
+        Flux.Optimise.update!(opt, ps, gs)
+    end
+
+    train_res = eval_loss_accuracy(X, y, train_ids, model, g)
+    test_res = eval_loss_accuracy(X, y, test_ids, model, g)
+    return train_res, test_res
+end
+
+for Layer in [
+        (nin, nout) -> GCNConv(nin => nout, relu),
+        (nin, nout) -> GraphConv(nin => nout, relu, aggr=mean),
+        (nin, nout) -> SAGEConv(nin => nout, relu),
+        (nin, nout) -> GATConv(nin => nout, relu),
+        (nin, nout) -> GATConv(nin => nout÷2, relu, heads=2),
+        (nin, nout) -> GINConv(Dense(nin, nout, relu)),
+        (nin, nout) -> ChebConv(nin => nout, 3),
+        # (nin, nout) -> NNConv(nin => nout), # needs edge features
+        # (nin, nout) -> GatedGraphConv(nout, 2), # needs nin = nout
+        # (nin, nout) -> EdgeConv(Dense(2nin, nout, relu)), # fits the training set but does not generalize well
+        ]
+    train_res, test_res = train(Layer, verbose=true)
+    # @show Layer(2,2) train_res, test_res
+    @test train_res.acc > 95
+    @test test_res.acc > 70
+end
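
In the test above, GNNChain threads the graph through every layer, so
model(g, X) behaves like Flux's Chain with the graph as a leading argument,
and plain Flux layers such as Dropout and Dense are applied to the node
features only. A minimal standalone sketch (rand_graph assumed available):

    using Flux, GraphNeuralNetworks

    g = rand_graph(8, 24)            # 8 nodes, 24 edges
    X = randn(Float32, 3, 8)         # 3 features per node

    model = GNNChain(GCNConv(3 => 16, relu),
                     Dropout(0.5),
                     Dense(16, 4))
    y = model(g, X)                  # 4×8 matrix of logits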

test/layers/conv.jl

Lines changed: 8 additions & 11 deletions

@@ -40,20 +40,17 @@
 end

 @testset "ChebConv" begin
-    k = 6
+    k = 3
     l = ChebConv(in_channel => out_channel, k)
     @test size(l.weight) == (out_channel, in_channel, k)
     @test size(l.bias) == (out_channel,)
     @test l.k == k
     for g in test_graphs
-        if g === g_single_vertex && GRAPH_T == :dense
-            @test_broken test_layer(l, g, rtol=1e-5, broken_grad_fields=[:weight], test_gpu=false)
-        else
-            test_layer(l, g, rtol=1e-5, broken_grad_fields=[:weight], test_gpu=false)
-            if TEST_GPU
-                @test_broken test_layer(l, g, rtol=1e-5, broken_grad_fields=[:weight], test_gpu=true)
-            end
-        end
+        g = add_self_loops(g)
+        test_layer(l, g, rtol=1e-5, test_gpu=false, outsize=(out_channel, g.num_nodes))
+        if TEST_GPU
+            @test_broken test_layer(l, g, rtol=1e-5, test_gpu=true, outsize=(out_channel, g.num_nodes))
+        end
     end

 @testset "bias=false" begin

@@ -81,10 +78,10 @@

 @testset "GATConv" begin

-    for heads in (1, 3), concat in (true, false)
+    for heads in (1, 2), concat in (true, false)
         l = GATConv(in_channel => out_channel; heads, concat)
         for g in test_graphs
-            test_layer(l, g, rtol=1e-4,
+            test_layer(l, g, rtol=1e-4,
                 outsize=(concat ? heads*out_channel : out_channel, g.num_nodes))
         end
     end
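
Calling add_self_loops before testing ChebConv guarantees every node has
nonzero degree, so the scaled Laplacian the layer builds never divides by
zero on graphs with isolated vertices (the case the old g_single_vertex
special-casing worked around). A sketch of the failure mode it avoids:

    using GraphNeuralNetworks

    g = GNNGraph([1], [2], num_nodes=3)   # node 3 is isolated (degree 0)
    g = add_self_loops(g)                 # now every degree is at least 1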

test/runtests.jl

Lines changed: 4 additions & 0 deletions

@@ -8,8 +8,11 @@ using LearnBase
 using LightGraphs
 using Zygote
 using Test
+using MLDatasets
 CUDA.allowscalar(false)

+ENV["DATADEPS_ALWAYS_ACCEPT"] = true # for MLDatasets
+
 include("test_utils.jl")

@@ -18,6 +21,7 @@ tests = [
     "layers/basic",
     "layers/conv",
     "layers/pool",
+    "examples/node_classification_cora",
 ]

 !CUDA.functional() && @warn("CUDA unavailable, not testing GPU support")
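
MLDatasets fetches Cora through DataDeps.jl, which normally prompts before
downloading; DATADEPS_ALWAYS_ACCEPT auto-accepts so a headless CI run does
not hang waiting for input. The first test run downloads and caches the
dataset:

    ENV["DATADEPS_ALWAYS_ACCEPT"] = "true"  # no TTY on CI, accept downloads
    using MLDatasets: Cora
    data = Cora.dataset()                   # downloaded once, cached after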
