
Commit 30585a3

test examples
1 parent 56acbeb commit 30585a3

7 files changed: +112 -14 lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
@@ -12,6 +12,7 @@ KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
 LearnBase = "7f8f8fb0-2700-5f03-b4bd-41f8cfc144b6"
 LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
@@ -36,8 +37,9 @@ julia = "1.6"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
 FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
+MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [targets]
-test = ["Test", "Adapt", "Zygote", "FiniteDifferences", "ChainRulesTestUtils"]
+test = ["Test", "Adapt", "Zygote", "FiniteDifferences", "ChainRulesTestUtils", "MLDatasets"]

examples/graph_classification_tudataset.jl

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 # An example of graph classification
 
 using Flux
-using Flux: @functor, dropout, onecold, onehotbatch, getindex
+using Flux: @functor, dropout, onecold, onehotbatch, getindex, cpu, gpu
 using Flux.Losses: logitbinarycrossentropy
 using Flux.Data: DataLoader
 using GraphNeuralNetworks
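
For context, cpu and gpu are the Flux helpers for moving arrays and models between devices; importing them explicitly lets the example call them unqualified. A minimal sketch of their use (values shown are illustrative only):

using Flux: gpu, cpu

x  = rand(Float32, 3, 4)
xg = gpu(x)   # moves the array to the GPU when CUDA is functional, otherwise a no-op
xc = cpu(xg)  # brings it back to host memory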

examples/node_classification_cora.jl

Lines changed: 4 additions & 5 deletions
@@ -1,7 +1,7 @@
 # An example of semi-supervised node classification
 
 using Flux
-using Flux: @functor, dropout, onecold, onehotbatch
+using Flux: onecold, onehotbatch
 using Flux.Losses: logitcrossentropy
 using GraphNeuralNetworks
 using MLDatasets: Cora
@@ -28,13 +28,12 @@ end
 
 function train(; kws...)
     args = Args(; kws...)
-    if args.seed > 0
-        Random.seed!(args.seed)
-        CUDA.seed!(args.seed)
-    end
+
+    args.seed > 0 && Random.seed!(args.seed)
 
     if args.usecuda && CUDA.functional()
         device = gpu
+        args.seed > 0 && CUDA.seed!(args.seed)
         @info "Training on GPU"
     else
         device = cpu
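
The change seeds the CUDA RNG only on the branch where a functional GPU is actually selected, presumably because CUDA.seed! cannot succeed without a usable device. The pattern in isolation (a minimal sketch, not part of the commit):

using Random, CUDA

seed = 17
seed > 0 && Random.seed!(seed)      # CPU RNG: always safe to seed

if CUDA.functional()
    seed > 0 && CUDA.seed!(seed)    # GPU RNG: seed only when a usable device exists
end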

src/layers/conv.jl

Lines changed: 4 additions & 6 deletions
@@ -288,9 +288,8 @@ end
 
 
 function Base.show(io::IO, l::GATConv)
-    in_channel = size(l.weight, ndims(l.weight))
-    out_channel = size(l.weight, ndims(l.weight)-1)
-    print(io, "GATConv(", in_channel, "=>", out_channel)
+    out_channel, in_channel = size(l.weight)
+    print(io, "GATConv(", in_channel, "=>", out_channel ÷ l.heads)
     print(io, ", LeakyReLU(λ=", l.negative_slope)
     print(io, "))")
 end
@@ -567,9 +566,8 @@ function (l::SAGEConv)(g::GNNGraph, x::AbstractMatrix)
 end
 
 function Base.show(io::IO, l::SAGEConv)
-    in_channel = size(l.weight1, ndims(l.weight1))
-    out_channel = size(l.weight1, ndims(l.weight1)-1)
-    print(io, "SAGEConv(", in_channel, " => ", out_channel)
+    out_channel, in_channel = size(l.weight)
+    print(io, "SAGEConv(", in_channel ÷ 2, " => ", out_channel)
     l.σ == identity || print(io, ", ", l.σ)
     print(io, ", aggr=", l.aggr)
     print(io, ")")

test/examples.jl

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
using Flux
using Flux: onecold, onehotbatch
using Flux.Losses: logitcrossentropy
using GraphNeuralNetworks
using MLDatasets: Cora
using Statistics, Random
using CUDA
CUDA.allowscalar(false)

function eval_loss_accuracy(X, y, ids, model, g)
    ŷ = model(g, X)
    l = logitcrossentropy(ŷ[:,ids], y[:,ids])
    acc = mean(onecold(ŷ[:,ids] |> cpu) .== onecold(y[:,ids] |> cpu))
    return (loss = round(l, digits=4), acc = round(acc*100, digits=2))
end


# arguments for the `train` function
Base.@kwdef mutable struct Args
    η = 1f-3             # learning rate
    epochs = 20          # number of epochs
    seed = 17            # set seed > 0 for reproducibility
    usecuda = false      # if true use cuda (if available)
    nhidden = 128        # dimension of hidden features
end

function train(Layer; verbose=false, kws...)
    args = Args(; kws...)
    args.seed > 0 && Random.seed!(args.seed)

    if args.usecuda && CUDA.functional()
        device = Flux.gpu
        args.seed > 0 && CUDA.seed!(args.seed)
    else
        device = Flux.cpu
    end

    # LOAD DATA
    data = Cora.dataset()
    g = GNNGraph(data.adjacency_list) |> device
    X = data.node_features |> device
    y = onehotbatch(data.node_labels, 1:data.num_classes) |> device
    train_ids = data.train_indices |> device
    val_ids = data.val_indices |> device
    test_ids = data.test_indices |> device
    ytrain = y[:,train_ids]

    nin, nhidden, nout = size(X,1), args.nhidden, data.num_classes

    ## DEFINE MODEL
    model = GNNChain(Layer(nin, nhidden),
                     Dropout(0.5),
                     Layer(nhidden, nhidden),
                     Dense(nhidden, nout)) |> device

    ps = Flux.params(model)
    opt = ADAM(args.η)


    ## TRAINING
    function report(epoch)
        train = eval_loss_accuracy(X, y, train_ids, model, g)
        test = eval_loss_accuracy(X, y, test_ids, model, g)
        println("Epoch: $epoch  Train: $(train)  Test: $(test)")
    end

    verbose && report(0)
    for epoch in 1:args.epochs
        gs = Flux.gradient(ps) do
            ŷ = model(g, X)
            logitcrossentropy(ŷ[:,train_ids], ytrain)
        end
        verbose && report(epoch)
        Flux.Optimise.update!(opt, ps, gs)
    end

    train_res = eval_loss_accuracy(X, y, train_ids, model, g)
    test_res = eval_loss_accuracy(X, y, test_ids, model, g)
    return train_res, test_res
end

for Layer in [
            (nin, nout) -> GCNConv(nin => nout, relu),
            (nin, nout) -> GraphConv(nin => nout, relu, aggr=mean),
            (nin, nout) -> SAGEConv(nin => nout, relu),
            (nin, nout) -> GATConv(nin => nout, relu),
            (nin, nout) -> GATConv(nin => nout÷2, relu, heads=2),
            (nin, nout) -> GINConv(Dense(nin, nout, relu)),
            (nin, nout) -> ChebConv(nin => nout, 3),
            # (nin, nout) -> NNConv(nin => nout), # needs edge features
            # (nin, nout) -> GatedGraphConv(nout, 2), # needs nin = nout
            # (nin, nout) -> EdgeConv(Dense(2nin, nout, relu)), # Fits the training set but does not generalize well
        ]
    train_res, test_res = train(Layer, verbose=true)
    # @show Layer(2,2) train_res, test_res
    @test train_res.acc > 95
    @test test_res.acc > 70
end
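
The @test macros in the loop come from the Test stdlib, presumably loaded by test/runtests.jl before this file is included. A minimal sketch for trying a single configuration interactively (hypothetical usage, not part of the commit; assumes the definitions above are loaded):

using Test

# pick one layer constructor from the list above and train it on Cora
Layer = (nin, nout) -> GCNConv(nin => nout, relu)
train_res, test_res = train(Layer; verbose=true, usecuda=false)
@show train_res test_res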

test/layers/conv.jl

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@
 end
 
 @testset "ChebConv" begin
-    k = 6
+    k = 3
     l = ChebConv(in_channel => out_channel, k)
     @test size(l.weight) == (out_channel, in_channel, k)
     @test size(l.bias) == (out_channel,)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ tests = [
     "layers/basic",
     "layers/conv",
     "layers/pool",
+    "examples.jl",
 ]
 
 !CUDA.functional() && @warn("CUDA unavailable, not testing GPU support")
