Skip to content

Commit b2887b5

Browse files
working tudataset example on cpu
1 parent 15d230b commit b2887b5

File tree

2 files changed

+44
-20
lines changed

2 files changed

+44
-20
lines changed

examples/graph_classification_mutag.jl renamed to examples/graph_classification_tudataset.jl

Lines changed: 42 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,28 @@
11
# An example of semi-supervised node classification
22

33
using Flux
4-
using Flux: @functor, dropout, onecold, onehotbatch
4+
using Flux: @functor, dropout, onecold, onehotbatch, getindex
55
using Flux.Losses: logitbinarycrossentropy
6+
using Flux.Data: DataLoader
67
using GraphNeuralNetworks
78
using MLDatasets: TUDataset
89
using Statistics, Random
910
using CUDA
1011
CUDA.allowscalar(false)
1112

12-
function eval_loss_accuracy(model, g, X, y)
13-
ŷ = model(g, X) |> vec
14-
l = logitbinarycrossentropy(ŷ, y)
15-
acc = mean((2 .* ŷ .- 1) .* (2 .* y .- 1) .> 0)
16-
return (loss = round(l, digits=4), acc = round(acc*100, digits=2))
13+
function eval_loss_accuracy(model, data_loader, device)
14+
loss = 0.
15+
acc = 0.
16+
ntot = 0
17+
for (g, X, y) in data_loader
18+
g, X, y = g |> device, X |> device, y |> device
19+
n = length(y)
20+
ŷ = model(g, X) |> vec
21+
loss += logitbinarycrossentropy(ŷ, y) * n
22+
acc += mean((2 .* ŷ .- 1) .* (2 .* y .- 1) .> 0) * n
23+
ntot += n
24+
end
25+
return (loss = round(loss/ntot, digits=4), acc = round(acc*100/ntot, digits=2))
1726
end
1827

1928
struct GNNData
@@ -22,6 +31,16 @@ struct GNNData
2231
y
2332
end
2433

34+
Base.getindex(data::GNNData, i::Int) = getindex(data, [i])
35+
36+
function Base.getindex(data::GNNData, i::AbstractVector)
37+
sg, nodemap = subgraph(data.g, i)
38+
return (sg, data.X[:,nodemap], data.y[i])
39+
end
40+
41+
# Flux's Dataloader compatibility.
42+
Flux.Data._nobs(data::GNNData) = data.g.num_graphs
43+
Flux.Data._getobs(data::GNNData, i) = data[i]
2544

2645
function getdataset(idxs)
2746
data = TUDataset("MUTAG")[idxs]
@@ -37,9 +56,10 @@ end
3756
# arguments for the `train` function
3857
Base.@kwdef mutable struct Args
3958
η = 1f-3 # learning rate
40-
epochs = 1000 # number of epochs
59+
batchsize = 64 # batch size (number of graphs in each batch)
60+
epochs = 200 # number of epochs
4161
seed = 17 # set seed > 0 for reproducibility
42-
use_cuda = false # if true use cuda (if available)
62+
usecuda = true # if true use cuda (if available)
4363
nhidden = 128 # dimension of hidden features
4464
infotime = 10 # report every `infotime` epochs
4565
end
@@ -48,7 +68,7 @@ function train(; kws...)
4868
args = Args(; kws...)
4969
args.seed > 0 && Random.seed!(args.seed)
5070

51-
if args.use_cuda && CUDA.functional()
71+
if args.usecuda && CUDA.functional()
5272
device = gpu
5373
args.seed > 0 && CUDA.seed!(args.seed)
5474
@info "Training on GPU"
@@ -61,12 +81,15 @@ function train(; kws...)
6181

6282
permindx = randperm(188)
6383
ntrain = 150
64-
gtrain, Xtrain, ytrain = getdataset(permindx[1:ntrain])
65-
gtest, Xtest, ytest = getdataset(permindx[ntrain+1:end])
84+
dtrain = getdataset(permindx[1:ntrain])
85+
dtest = getdataset(permindx[ntrain+1:end])
86+
87+
train_loader = DataLoader(dtrain, batchsize=args.batchsize, shuffle=true)
88+
test_loader = DataLoader(dtest, batchsize=args.batchsize, shuffle=false)
6689

6790
# DEFINE MODEL
6891

69-
nin = size(Xtrain,1)
92+
nin = size(dtrain.X, 1)
7093
nhidden = args.nhidden
7194

7295
model = GNNChain(GCNConv(nin => nhidden, relu),
@@ -82,22 +105,23 @@ function train(; kws...)
82105
# LOGGING FUNCTION
83106

84107
function report(epoch)
85-
train = eval_loss_accuracy(model, gtrain, Xtrain, ytrain)
86-
test = eval_loss_accuracy(model, gtest, Xtest, ytest)
108+
train = eval_loss_accuracy(model, train_loader, device)
109+
test = eval_loss_accuracy(model, test_loader, device)
87110
println("Epoch: $epoch Train: $(train) Test: $(test)")
88111
end
89112

90113
# TRAIN
91114

92115
report(0)
93116
for epoch in 1:args.epochs
94-
# for (g, X, y) in train_loader
117+
for (g, X, y) in train_loader
118+
g, X, y = g |> device, X |> device, y |> device
95119
gs = Flux.gradient(ps) do
96-
ŷ = model(gtrain, Xtrain) |> vec
97-
logitbinarycrossentropy(ŷ, ytrain)
120+
ŷ = model(g, X) |> vec
121+
logitbinarycrossentropy(ŷ, y)
98122
end
99123
Flux.Optimise.update!(opt, ps, gs)
100-
# end
124+
end
101125

102126
epoch % args.infotime == 0 && report(epoch)
103127
end

examples/node_classification_cora.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ Base.@kwdef mutable struct Args
2121
η = 1f-3 # learning rate
2222
epochs = 100 # number of epochs
2323
seed = 17 # set seed > 0 for reproducibility
24-
use_cuda = true # if true use cuda (if available)
24+
usecuda = true # if true use cuda (if available)
2525
nhidden = 128 # dimension of hidden features
2626
infotime = 10 # report every `infotime` epochs
2727
end
@@ -33,7 +33,7 @@ function train(; kws...)
3333
CUDA.seed!(args.seed)
3434
end
3535

36-
if args.use_cuda && CUDA.functional()
36+
if args.usecuda && CUDA.functional()
3737
device = gpu
3838
@info "Training on GPU"
3939
else

0 commit comments

Comments
 (0)