Commit cef511b

add cora example
1 parent 4510d98 · commit cef511b

File tree

1 file changed: +107 −0


examples/cora.jl

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
# An example of semi-supervised node classification

using Flux
using Flux: @functor, dropout, onecold, onehotbatch
using Flux.Losses: logitcrossentropy
using GraphNeuralNetworks
using MLDatasets: Cora
using Statistics, Random
using CUDA
CUDA.allowscalar(false)

# Two graph convolution layers followed by a dense classification head.
struct GNN
    conv1
    conv2
    dense
end

@functor GNN

function GNN(; nin, nhidden, nout)
    GNN(GCNConv(nin => nhidden, relu),
        GCNConv(nhidden => nhidden, relu),
        Dense(nhidden, nout))
end

function (net::GNN)(fg, x)
    x = net.conv1(fg, x)
    x = dropout(x, 0.5)          # regularization between the two convolutions
    x = net.conv2(fg, x)
    x = net.dense(x)
    return x
end

# Loss and accuracy restricted to the node subset `ids`.
function eval_loss_accuracy(X, y, ids, model, fg)
    ŷ = model(fg, X)
    l = logitcrossentropy(ŷ[:,ids], y[:,ids])
    acc = mean(onecold(ŷ[:,ids] |> cpu) .== onecold(y[:,ids] |> cpu))
    return (loss = l |> round4, acc = acc*100 |> round4)
end

## utility functions
num_params(model) = sum(length, Flux.params(model))
round4(x) = round(x, digits=4)

# arguments for the `train` function
Base.@kwdef mutable struct Args
    η = 1f-3             # learning rate
    epochs = 100         # number of epochs
    seed = 17            # set seed > 0 for reproducibility
    use_cuda = true      # if true use cuda (if available)
    nhidden = 128        # dimension of hidden features
    infotime = 10        # report every `infotime` epochs
end

function train(; kws...)
    args = Args(; kws...)
    if args.seed > 0
        Random.seed!(args.seed)
        CUDA.seed!(args.seed)
    end

    if args.use_cuda && CUDA.functional()
        device = gpu
        @info "Training on GPU"
    else
        device = cpu
        @info "Training on CPU"
    end

    # load the Cora citation graph and move features/labels to the device
    data = Cora.dataset()
    fg = FeaturedGraph(data.adjacency_list)
    X = data.node_features |> device
    y = onehotbatch(data.node_labels, 1:data.num_classes) |> device
    train_ids = data.train_indices |> device
    val_ids = data.val_indices |> device
    test_ids = data.test_indices |> device

    model = GNN(nin=size(X,1),
                nhidden=args.nhidden,
                nout=data.num_classes) |> device
    ps = Flux.params(model)
    opt = ADAM(args.η)

    @info "NUM NODES: $(fg.num_nodes)  NUM EDGES: $(fg.num_edges)"

    function report(epoch)
        train = eval_loss_accuracy(X, y, train_ids, model, fg)
        val = eval_loss_accuracy(X, y, val_ids, model, fg)
        test = eval_loss_accuracy(X, y, test_ids, model, fg)
        println("Epoch: $epoch   Train: $(train)   Val: $(val)   Test: $(test)")
    end

    ## TRAINING
    report(0)
    for epoch in 1:args.epochs
        # full-batch gradient step: the loss is computed on the training nodes only
        gs = Flux.gradient(ps) do
            ŷ = model(fg, X)
            logitcrossentropy(ŷ[:,train_ids], y[:,train_ids])
        end

        Flux.Optimise.update!(opt, ps, gs)

        epoch % args.infotime == 0 && report(epoch)
    end

    return fg, X, model, y, data
end
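
For reference, a minimal way to run this example (a sketch, not part of the commit: it assumes the file is included from the repository root, that MLDatasets can download Cora on first use, and that keyword arguments to `train` map onto the `Args` fields above):

julia> include("examples/cora.jl")

julia> fg, X, model, y, data = train(epochs = 200, use_cuda = false);  # override any Args field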
