Skip to content

Commit edef2c6

Browse files
dataloader support; working gpu
1 parent b2887b5 commit edef2c6

File tree

5 files changed

+41
-15
lines changed

5 files changed

+41
-15
lines changed

examples/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,5 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
44
GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
55
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
66
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
7+
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
8+
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"

examples/graph_classification_tudataset.jl

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,14 @@ function Base.getindex(data::GNNData, i::AbstractVector)
3838
return (sg, data.X[:,nodemap], data.y[i])
3939
end
4040

41-
# Flux's Dataloader compatibility.
41+
# Flux's Dataloader compatibility. Related PR https://github.com/FluxML/Flux.jl/pull/1683
4242
Flux.Data._nobs(data::GNNData) = data.g.num_graphs
4343
Flux.Data._getobs(data::GNNData, i) = data[i]
4444

45-
function getdataset(idxs)
46-
data = TUDataset("MUTAG")[idxs]
47-
@info "MUTAG: num_nodes: $(data.num_nodes) num_edges: $(data.num_edges) num_graphs: $(data.num_graphs)"
45+
function process_dataset(data)
4846
g = GNNGraph(data.source, data.target, num_nodes=data.num_nodes, graph_indicator=data.graph_indicator)
4947
X = Array{Float32}(onehotbatch(data.node_labels, 0:6))
48+
# The dataset also has edge features but we won't be using them
5049
# E = Array{Float32}(onehotbatch(data.edge_labels, sort(unique(data.edge_labels))))
5150
y = (1 .+ Array{Float32}(data.graph_labels)) ./ 2
5251
@assert all(([0,1]), y) # binary classification
@@ -78,12 +77,18 @@ function train(; kws...)
7877
end
7978

8079
# LOAD DATA
80+
81+
NUM_TRAIN = 150
82+
full_data = TUDataset("MUTAG")
8183

82-
permindx = randperm(188)
83-
ntrain = 150
84-
dtrain = getdataset(permindx[1:ntrain])
85-
dtest = getdataset(permindx[ntrain+1:end])
84+
@info "MUTAG DATASET
85+
num_nodes: $(full_data.num_nodes)
86+
num_edges: $(full_data.num_edges)
87+
num_graphs: $(full_data.num_graphs)"
8688

89+
perm = randperm(full_data.num_graphs)
90+
dtrain = process_dataset(full_data[perm[1:NUM_TRAIN]])
91+
dtest = process_dataset(full_data[perm[NUM_TRAIN+1:end]])
8792
train_loader = DataLoader(dtrain, batchsize=args.batchsize, shuffle=true)
8893
test_loader = DataLoader(dtest, batchsize=args.batchsize, shuffle=false)
8994

@@ -92,9 +97,9 @@ function train(; kws...)
9297
nin = size(dtrain.X, 1)
9398
nhidden = args.nhidden
9499

95-
model = GNNChain(GCNConv(nin => nhidden, relu),
100+
model = GNNChain(GraphConv(nin => nhidden, relu),
96101
Dropout(0.5),
97-
GCNConv(nhidden => nhidden, relu),
102+
GraphConv(nhidden => nhidden, relu),
98103
GlobalPool(mean),
99104
Dense(nhidden, 1)) |> device
100105

@@ -127,4 +132,4 @@ function train(; kws...)
127132
end
128133
end
129134

130-
# train()
135+
train()

src/gnngraph.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,6 @@ function subgraph(g::GNNGraph{<:COO_T}, i::AbstractVector)
490490
s = [nodemap[i] for i in s[edge_mask]]
491491
t = [nodemap[i] for i in t[edge_mask]]
492492
w = isnothing(w) ? nothing : w[edge_mask]
493-
@show size(g.nf) size(node_mask)
494493
nf = isnothing(g.nf) ? nothing : g.nf[:,node_mask]
495494
ef = isnothing(g.ef) ? nothing : g.ef[:,edge_mask]
496495
gf = isnothing(g.gf) ? nothing : g.gf[:,i]

test/gnngraph.jl

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
@test sort(outneighbors(g, 1)) == [2, 4]
1818
@test sort(inneighbors(g, 1)) == [2, 4]
1919
@test is_directed(g) == true
20-
s1, t1 = sort_edge_index(edge_index(g))
20+
s1, t1 = GraphNeuralNetworks.sort_edge_index(edge_index(g))
2121
@test s1 == s
2222
@test t1 == t
2323

@@ -65,7 +65,7 @@
6565
@test sort(outneighbors(g, 1)) == [2]
6666
@test sort(inneighbors(g, 1)) == [4]
6767
@test is_directed(g) == true
68-
s1, t1 = sort_edge_index(edge_index(g))
68+
s1, t1 = GraphNeuralNetworks.sort_edge_index(edge_index(g))
6969
@test s1 == s
7070
@test t1 == t
7171

@@ -103,6 +103,7 @@
103103
end
104104

105105
@testset "batch" begin
106+
#TODO add graph_type=GRAPH_T
106107
g1 = GNNGraph(random_regular_graph(10,2), nf=rand(16,10))
107108
g2 = GNNGraph(random_regular_graph(4,2), nf=rand(16,4))
108109
g3 = GNNGraph(random_regular_graph(7,2), nf=rand(16,7))
@@ -112,5 +113,25 @@
112113

113114
g123 = Flux.batch([g1, g2, g3])
114115
@test g123.graph_indicator == [fill(1, 10); fill(2, 4); fill(3, 7)]
116+
117+
s, t = edge_index(g123)
118+
@test s == [edge_index(g1)[1]; 10 .+ edge_index(g2)[1]; 14 .+ edge_index(g3)[1]]
119+
@test t == [edge_index(g1)[2]; 10 .+ edge_index(g2)[2]; 14 .+ edge_index(g3)[2]]
120+
@test g123.nf[:,11:14] g2.nf
121+
end
122+
123+
@testset "subgraph" begin
124+
#TODO add graph_type=GRAPH_T
125+
g1 = GNNGraph(random_regular_graph(10,2), nf=rand(16,10))
126+
g2 = GNNGraph(random_regular_graph(4,2), nf=rand(16,4))
127+
g3 = GNNGraph(random_regular_graph(7,2), nf=rand(16,7))
128+
g = Flux.batch([g1, g2, g3])
129+
g2b, nodemap = subgraph(g, 2)
130+
131+
s, t = edge_index(g2b)
132+
@test s == edge_index(g2)[1]
133+
@test t == edge_index(g2)[2]
134+
@test g2b.nf g2.nf
115135
end
136+
116137
end

test/runtests.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using GraphNeuralNetworks
2-
using GraphNeuralNetworks: sort_edge_index
32
using Flux
43
using CUDA
54
using Flux: gpu, @functor

0 commit comments

Comments
 (0)