Skip to content

Commit 68a3ad4

Browse files
clenaup
1 parent d383abf commit 68a3ad4

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

examples/graph_classification_tudataset.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
# An example of graph classification
22

33
using Flux
4-
using Flux: @functor, dropout, onecold, onehotbatch, getindex, cpu, gpu
4+
using Flux:onecold, onehotbatch
55
using Flux.Losses: logitbinarycrossentropy
66
using Flux.Data: DataLoader
77
using GraphNeuralNetworks
88
using MLDatasets: TUDataset
99
using Statistics, Random
10-
using LearnBase: getobs
1110
using CUDA
1211
CUDA.allowscalar(false)
1312

@@ -76,8 +75,8 @@ function train(; kws...)
7675
@info gfull
7776

7877
perm = randperm(gfull.num_graphs)
79-
gtrain = getobs(gfull, perm[1:NUM_TRAIN])
80-
gtest = getobs(gfull, perm[NUM_TRAIN+1:end])
78+
gtrain, _ = getgraph(gfull, perm[1:NUM_TRAIN])
79+
gtest, _ = getgraph(gfull, perm[NUM_TRAIN+1:end])
8180
train_loader = DataLoader(gtrain, batchsize=args.batchsize, shuffle=true)
8281
test_loader = DataLoader(gtest, batchsize=args.batchsize, shuffle=false)
8382

@@ -121,4 +120,4 @@ function train(; kws...)
121120
end
122121
end
123122

124-
# train()
123+
train()

0 commit comments

Comments
 (0)