1
1
# An example of semi-supervised node classification
2
2
3
3
using Flux
4
- using Flux: @functor , dropout, onecold, onehotbatch
4
+ using Flux: @functor , dropout, onecold, onehotbatch, getindex
5
5
using Flux. Losses: logitbinarycrossentropy
6
+ using Flux. Data: DataLoader
6
7
using GraphNeuralNetworks
7
8
using MLDatasets: TUDataset
8
9
using Statistics, Random
9
10
using CUDA
10
11
CUDA. allowscalar (false )
11
12
12
- function eval_loss_accuracy (model, g, X, y)
13
- ŷ = model (g, X) |> vec
14
- l = logitbinarycrossentropy (ŷ, y)
15
- acc = mean ((2 .* ŷ .- 1 ) .* (2 .* y .- 1 ) .> 0 )
16
- return (loss = round (l, digits= 4 ), acc = round (acc* 100 , digits= 2 ))
13
+ function eval_loss_accuracy (model, data_loader, device)
14
+ loss = 0.
15
+ acc = 0.
16
+ ntot = 0
17
+ for (g, X, y) in data_loader
18
+ g, X, y = g |> device, X |> device, y |> device
19
+ n = length (y)
20
+ ŷ = model (g, X) |> vec
21
+ loss += logitbinarycrossentropy (ŷ, y) * n
22
+ acc += mean ((2 .* ŷ .- 1 ) .* (2 .* y .- 1 ) .> 0 ) * n
23
+ ntot += n
24
+ end
25
+ return (loss = round (loss/ ntot, digits= 4 ), acc = round (acc* 100 / ntot, digits= 2 ))
17
26
end
18
27
19
28
struct GNNData
@@ -22,6 +31,16 @@ struct GNNData
22
31
y
23
32
end
24
33
34
+ Base. getindex (data:: GNNData , i:: Int ) = getindex (data, [i])
35
+
36
+ function Base. getindex (data:: GNNData , i:: AbstractVector )
37
+ sg, nodemap = subgraph (data. g, i)
38
+ return (sg, data. X[:,nodemap], data. y[i])
39
+ end
40
+
41
+ # Flux's Dataloader compatibility.
42
+ Flux. Data. _nobs (data:: GNNData ) = data. g. num_graphs
43
+ Flux. Data. _getobs (data:: GNNData , i) = data[i]
25
44
26
45
function getdataset (idxs)
27
46
data = TUDataset (" MUTAG" )[idxs]
37
56
# arguments for the `train` function
38
57
Base. @kwdef mutable struct Args
39
58
η = 1f-3 # learning rate
40
- epochs = 1000 # number of epochs
59
+ batchsize = 64 # batch size (number of graphs in each batch)
60
+ epochs = 200 # number of epochs
41
61
seed = 17 # set seed > 0 for reproducibility
42
- use_cuda = false # if true use cuda (if available)
62
+ usecuda = true # if true use cuda (if available)
43
63
nhidden = 128 # dimension of hidden features
44
64
infotime = 10 # report every `infotime` epochs
45
65
end
@@ -48,7 +68,7 @@ function train(; kws...)
48
68
args = Args (; kws... )
49
69
args. seed > 0 && Random. seed! (args. seed)
50
70
51
- if args. use_cuda && CUDA. functional ()
71
+ if args. usecuda && CUDA. functional ()
52
72
device = gpu
53
73
args. seed > 0 && CUDA. seed! (args. seed)
54
74
@info " Training on GPU"
@@ -61,12 +81,15 @@ function train(; kws...)
61
81
62
82
permindx = randperm (188 )
63
83
ntrain = 150
64
- gtrain, Xtrain, ytrain = getdataset (permindx[1 : ntrain])
65
- gtest, Xtest, ytest = getdataset (permindx[ntrain+ 1 : end ])
84
+ dtrain = getdataset (permindx[1 : ntrain])
85
+ dtest = getdataset (permindx[ntrain+ 1 : end ])
86
+
87
+ train_loader = DataLoader (dtrain, batchsize= args. batchsize, shuffle= true )
88
+ test_loader = DataLoader (dtest, batchsize= args. batchsize, shuffle= false )
66
89
67
90
# DEFINE MODEL
68
91
69
- nin = size (Xtrain, 1 )
92
+ nin = size (dtrain . X, 1 )
70
93
nhidden = args. nhidden
71
94
72
95
model = GNNChain (GCNConv (nin => nhidden, relu),
@@ -82,22 +105,23 @@ function train(; kws...)
82
105
# LOGGING FUNCTION
83
106
84
107
function report (epoch)
85
- train = eval_loss_accuracy (model, gtrain, Xtrain, ytrain )
86
- test = eval_loss_accuracy (model, gtest, Xtest, ytest )
108
+ train = eval_loss_accuracy (model, train_loader, device )
109
+ test = eval_loss_accuracy (model, test_loader, device )
87
110
println (" Epoch: $epoch Train: $(train) Test: $(test) " )
88
111
end
89
112
90
113
# TRAIN
91
114
92
115
report (0 )
93
116
for epoch in 1 : args. epochs
94
- # for (g, X, y) in train_loader
117
+ for (g, X, y) in train_loader
118
+ g, X, y = g |> device, X |> device, y |> device
95
119
gs = Flux. gradient (ps) do
96
- ŷ = model (gtrain, Xtrain ) |> vec
97
- logitbinarycrossentropy (ŷ, ytrain )
120
+ ŷ = model (g, X ) |> vec
121
+ logitbinarycrossentropy (ŷ, y )
98
122
end
99
123
Flux. Optimise. update! (opt, ps, gs)
100
- # end
124
+ end
101
125
102
126
epoch % args. infotime == 0 && report (epoch)
103
127
end
0 commit comments