Skip to content

Commit 749e038

Browse files
Merge pull request #82 from CarloLucibello/cl/gf
GeometricFlux cora comparison
2 parents 91e93e0 + b349bc4 commit 749e038

File tree

2 files changed

+89
-0
lines changed

2 files changed

+89
-0
lines changed

examples/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
33
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
44
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
55
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
6+
GeometricFlux = "7e08b658-56d3-11e9-2997-919d5b31e4ea"
67
GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
8+
GraphSignals = "3ebe565e-a4b5-49c6-aed2-300248c3a9c1"
79
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
810
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
911
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# An example of semi-supervised node classification
2+
3+
using Flux
4+
using Flux: onecold, onehotbatch
5+
using Flux.Losses: logitcrossentropy
6+
using GeometricFlux, GraphSignals
7+
using MLDatasets: Cora
8+
using Statistics, Random
9+
using CUDA
10+
CUDA.allowscalar(false)
11+
12+
function eval_loss_accuracy(X, y, ids, model)
13+
= model(X)
14+
l = logitcrossentropy(ŷ[:,ids], y[:,ids])
15+
acc = mean(onecold(ŷ[:,ids]) .== onecold(y[:,ids]))
16+
return (loss = round(l, digits=4), acc = round(acc*100, digits=2))
17+
end
18+
19+
# arguments for the `train` function
20+
Base.@kwdef mutable struct Args
21+
η = 1f-3 # learning rate
22+
epochs = 100 # number of epochs
23+
seed = 17 # set seed > 0 for reproducibility
24+
usecuda = true # if true use cuda (if available)
25+
nhidden = 128 # dimension of hidden features
26+
infotime = 10 # report every `infotime` epochs
27+
end
28+
29+
function train(; kws...)
30+
args = Args(; kws...)
31+
32+
args.seed > 0 && Random.seed!(args.seed)
33+
34+
if args.usecuda && CUDA.functional()
35+
device = gpu
36+
args.seed > 0 && CUDA.seed!(args.seed)
37+
@info "Training on GPU"
38+
else
39+
device = cpu
40+
@info "Training on CPU"
41+
end
42+
43+
# LOAD DATA
44+
data = Cora.dataset()
45+
g = FeaturedGraph(data.adjacency_list) |> device
46+
X = data.node_features |> device
47+
y = onehotbatch(data.node_labels, 1:data.num_classes) |> device
48+
train_ids = data.train_indices |> device
49+
val_ids = data.val_indices |> device
50+
test_ids = data.test_indices |> device
51+
ytrain = y[:,train_ids]
52+
53+
nin, nhidden, nout = size(X,1), args.nhidden, data.num_classes
54+
55+
## DEFINE MODEL
56+
model = Chain(GCNConv(g, nin => nhidden, relu),
57+
Dropout(0.5),
58+
GCNConv(g, nhidden => nhidden, relu),
59+
Dense(nhidden, nout)) |> device
60+
61+
ps = Flux.params(model)
62+
opt = ADAM(args.η)
63+
64+
@info g
65+
66+
## LOGGING FUNCTION
67+
function report(epoch)
68+
train = eval_loss_accuracy(X, y, train_ids, model)
69+
test = eval_loss_accuracy(X, y, test_ids, model)
70+
println("Epoch: $epoch Train: $(train) Test: $(test)")
71+
end
72+
73+
## TRAINING
74+
report(0)
75+
for epoch in 1:args.epochs
76+
gs = Flux.gradient(ps) do
77+
= model(X)
78+
logitcrossentropy(ŷ[:,train_ids], ytrain)
79+
end
80+
81+
Flux.Optimise.update!(opt, ps, gs)
82+
83+
epoch % args.infotime == 0 && report(epoch)
84+
end
85+
end
86+
87+
train(usecuda=false)

0 commit comments

Comments
 (0)