Skip to content

Commit 401147e

Browse files
dataloader support
1 parent b5512db commit 401147e

File tree

9 files changed

+219
-89
lines changed

9 files changed

+219
-89
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1515
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1616
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1717
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
18+
PrettyPrint = "8162dcfd-2161-5ef2-ae6c-7681170c5f98"
1819
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1920
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2021
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

docs/make.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using Documenter
44
DocMeta.setdocmeta!(GraphNeuralNetworks, :DocTestSetup, :(using GraphNeuralNetworks); recursive=true)
55

66
makedocs(;
7-
modules=[GraphNeuralNetworks],
7+
modules=[GraphNeuralNetworks, NNlib],
88
sitename = "GraphNeuralNetworks.jl",
99
pages = ["Home" => "index.md",
1010
"GNNGraph" => "gnngraph.md",

docs/src/gnngraph.md

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,78 @@
22

33
TODO
44

5-
```@docs
6-
GNNGraph
5+
## Graph Creation
6+
7+
8+
```julia
9+
using GraphNeuralNetworks, LightGraphs, SparseArrays
10+
11+
12+
# From LightGraphs's graph
13+
lg_graph = erdos_renyi(10, 0.3)
14+
g = GNNGraph(lg_graph)
15+
16+
17+
# From adjacency matrix
18+
A = sprand(10, 10, 0.3)
19+
20+
g = GNNGraph(A)
21+
22+
@assert adjacency_matrix(g) == A
23+
24+
# From adjacency list
25+
adjlist = [[] [] [] ]
26+
27+
g = GNNGraph(adjlist)
28+
29+
@assert sort.(adjacency_list(g)) == sort.(adjlist)
30+
31+
# From COO representation
32+
source = []
33+
target = []
34+
g = GNNGraph(source, target)
35+
@assert edge_index(g) == (source, target)
36+
```
37+
38+
We have also seen some useful methods such as [`adjacency_matrix`](@ref) and [`edge_index`](@ref).
39+
40+
41+
42+
## Data Features
43+
44+
```julia
45+
GNNGraph(sprand(10, 0.3), ndata = (; X=rand(32, 10)))
46+
# or equivalently
47+
GNNGraph(sprand(10, 0.3), ndata=rand(32, 10))
48+
49+
50+
g = GNNGraph(sprand(10, 0.3), ndata = (X=rand(32, 10), y=rand(10)))
51+
52+
g = GNNGraph(g, edata=rand(6, g.num_edges))
53+
```
54+
55+
56+
## Graph Manipulation
57+
58+
```julia
59+
g = add_self_loops(g)
60+
61+
g = remove_self_loops(g)
62+
```
63+
64+
## Batches and Subgraphs
65+
66+
```julia
67+
g = Flux.batch([g1, g2, g3])
68+
69+
subgraph(g, 2:3)
70+
```
71+
72+
73+
## LightGraphs integration
74+
75+
```julia
76+
@assert LightGraphs.isdirected(g)
777
```
878

79+
## Other methods

examples/graph_classification_tudataset.jl

Lines changed: 35 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -7,49 +7,39 @@ using Flux.Data: DataLoader
77
using GraphNeuralNetworks
88
using MLDatasets: TUDataset
99
using Statistics, Random
10+
using LearnBase: getobs
1011
using CUDA
1112
CUDA.allowscalar(false)
1213

1314
function eval_loss_accuracy(model, data_loader, device)
1415
loss = 0.
1516
acc = 0.
1617
ntot = 0
17-
for (g, X, y) in data_loader
18-
g, X, y = g |> device, X |> device, y |> device
19-
n = length(y)
20-
ŷ = model(g, X) |> vec
18+
for g in data_loader
19+
g = g |> device
20+
n = g.num_graphs
21+
y = g.gdata.y
22+
ŷ = model(g, g.ndata.X) |> vec
2123
loss += logitbinarycrossentropy(ŷ, y) * n
2224
acc += mean((2 .* ŷ .- 1) .* (2 .* y .- 1) .> 0) * n
2325
ntot += n
24-
end
26+
end
2527
return (loss = round(loss/ntot, digits=4), acc = round(acc*100/ntot, digits=2))
2628
end
2729

28-
struct GNNData
29-
g
30-
X
31-
y
32-
end
33-
34-
Base.getindex(data::GNNData, i::Int) = getindex(data, [i])
35-
36-
function Base.getindex(data::GNNData, i::AbstractVector)
37-
sg, nodemap = subgraph(data.g, i)
38-
return (sg, data.X[:,nodemap], data.y[i])
39-
end
40-
41-
# Flux's Dataloader compatibility. Related PR https://github.com/FluxML/Flux.jl/pull/1683
42-
Flux.Data._nobs(data::GNNData) = data.g.num_graphs
43-
Flux.Data._getobs(data::GNNData, i) = data[i]
44-
45-
function process_dataset(data)
46-
g = GNNGraph(data.source, data.target, num_nodes=data.num_nodes, graph_indicator=data.graph_indicator)
30+
function getdataset()
31+
data = TUDataset("MUTAG")
32+
4733
X = Array{Float32}(onehotbatch(data.node_labels, 0:6))
48-
# The dataset also has edge features but we won't be using them
49-
# E = Array{Float32}(onehotbatch(data.edge_labels, sort(unique(data.edge_labels))))
5034
y = (1 .+ Array{Float32}(data.graph_labels)) ./ 2
5135
@assert all(∈([0,1]), y) # binary classification
52-
return GNNData(g, X, y)
36+
# The dataset also has edge features but we won't be using them
37+
E = Array{Float32}(onehotbatch(data.edge_labels, sort(unique(data.edge_labels))))
38+
39+
return GNNGraph(data.source, data.target,
40+
num_nodes=data.num_nodes,
41+
graph_indicator=data.graph_indicator,
42+
ndata=(; X), edata=(; E), gdata=(; y))
5343
end
5444

5545
# arguments for the `train` function
@@ -78,23 +68,25 @@ function train(; kws...)
7868

7969
# LOAD DATA
8070

71+
8172
NUM_TRAIN = 150
82-
full_data = TUDataset("MUTAG")
8373

74+
gfull = getdataset()
75+
8476
@info "MUTAG DATASET
85-
num_nodes: $(full_data.num_nodes)
86-
num_edges: $(full_data.num_edges)
87-
num_graphs: $(full_data.num_graphs)"
88-
89-
perm = randperm(full_data.num_graphs)
90-
dtrain = process_dataset(full_data[perm[1:NUM_TRAIN]])
91-
dtest = process_dataset(full_data[perm[NUM_TRAIN+1:end]])
92-
train_loader = DataLoader(dtrain, batchsize=args.batchsize, shuffle=true)
93-
test_loader = DataLoader(dtest, batchsize=args.batchsize, shuffle=false)
77+
num_nodes: $(gfull.num_nodes)
78+
num_edges: $(gfull.num_edges)
79+
num_graphs: $(gfull.num_graphs)"
80+
81+
perm = randperm(gfull.num_graphs)
82+
gtrain = getobs(gfull, perm[1:NUM_TRAIN])
83+
gtest = getobs(gfull, perm[NUM_TRAIN+1:end])
84+
train_loader = DataLoader(gtrain, batchsize=args.batchsize, shuffle=true)
85+
test_loader = DataLoader(gtest, batchsize=args.batchsize, shuffle=false)
9486

9587
# DEFINE MODEL
9688

97-
nin = size(dtrain.X, 1)
89+
nin = size(gtrain.ndata.X, 1)
9890
nhidden = args.nhidden
9991

10092
model = GNNChain(GraphConv(nin => nhidden, relu),
@@ -119,11 +111,11 @@ function train(; kws...)
119111

120112
report(0)
121113
for epoch in 1:args.epochs
122-
for (g, X, y) in train_loader
123-
g, X, y = g |> device, X |> device, y |> device
114+
for g in train_loader
115+
g = g |> device
124116
gs = Flux.gradient(ps) do
125-
ŷ = model(g, X) |> vec
126-
logitbinarycrossentropy(ŷ, y)
117+
ŷ = model(g, g.ndata.X) |> vec
118+
logitbinarycrossentropy(ŷ, g.gdata.y)
127119
end
128120
Flux.Optimise.update!(opt, ps, gs)
129121
end
@@ -132,4 +124,4 @@ function train(; kws...)
132124
end
133125
end
134126

135-
train()
127+
# train()

src/GraphNeuralNetworks.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@ import KrylovKit
77
using Base: tail
88
using CUDA
99
using Flux
10-
using Flux: glorot_uniform, leakyrelu, GRUCell, @functor
10+
using Flux: glorot_uniform, leakyrelu, GRUCell, @functor, batch
1111
using MacroTools: @forward
12+
import LearnBase
1213
using LearnBase: getobs
1314
using NNlib, NNlibCUDA
1415
using ChainRulesCore
@@ -26,6 +27,8 @@ export
2627

2728
# from LightGraphs
2829
adjacency_matrix,
30+
# from SparseArrays
31+
sprand, sparse,
2932

3033
# msgpass
3134
# update, update_edge, update_global, message, propagate,

src/gnngraph.jl

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,17 @@ containing the total number of nodes and edges of the original graphs.
418418
Equivalent to [`SparseArrays.blockdiag`](@ref).
419419
"""
420420
Flux.batch(xs::Vector{<:GNNGraph}) = blockdiag(xs...)
421+
422+
### LearnBase compatibility
423+
LearnBase.nobs(g::GNNGraph) = g.num_graphs
424+
LearnBase.getobs(g::GNNGraph, i) = subgraph(g, i)[1]
425+
426+
# Flux's Dataloader compatibility. Related PR https://github.com/FluxML/Flux.jl/pull/1683
427+
Flux.Data._nobs(g::GNNGraph) = g.num_graphs
428+
Flux.Data._getobs(g::GNNGraph, i) = subgraph(g, i)[1]
429+
421430
#########################
431+
Base.:(==)(g1::GNNGraph, g2::GNNGraph) = all(k -> getfield(g1,k)==getfield(g2,k), fieldnames(typeof(g1)))
422432

423433
"""
424434
subgraph(g::GNNGraph, i)
@@ -432,7 +442,12 @@ The node `i` in the subgraph corresponds to the node `nodes[i]` in `g`.
432442
"""
433443
subgraph(g::GNNGraph, i::Int) = subgraph(g::GNNGraph{<:COO_T}, [i])
434444

435-
function subgraph(g::GNNGraph{<:COO_T}, i::AbstractVector)
445+
function subgraph(g::GNNGraph{<:COO_T}, i::AbstractVector{Int})
446+
if g.graph_indicator === nothing
447+
@assert i == [1]
448+
return g
449+
end
450+
436451
node_mask = g.graph_indicator .∈ Ref(i)
437452

438453
nodes = (1:g.num_nodes)[node_mask]
@@ -446,8 +461,9 @@ function subgraph(g::GNNGraph{<:COO_T}, i::AbstractVector)
446461
s = [nodemap[i] for i in s[edge_mask]]
447462
t = [nodemap[i] for i in t[edge_mask]]
448463
w = isnothing(w) ? nothing : w[edge_mask]
464+
449465
ndata = getobs(g.ndata, node_mask)
450-
edata = getobs(g.ndata, edge_mask)
466+
edata = getobs(g.edata, edge_mask)
451467
gdata = getobs(g.gdata, i)
452468

453469
num_nodes = length(graph_indicator)
@@ -461,7 +477,6 @@ function subgraph(g::GNNGraph{<:COO_T}, i::AbstractVector)
461477
return gnew, nodes
462478
end
463479

464-
### TO DEPRECATE ?? ###
465480
function node_features(g::GNNGraph)
466481
if isempty(g.ndata)
467482
return nothing
@@ -491,7 +506,6 @@ function global_features(g::GNNGraph)
491506
return g.gdata[1]
492507
end
493508
end
494-
#########
495509

496510

497511
@non_differentiable normalized_laplacian(x...)

src/layers/pool.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,25 @@ Takes a graph and feature nodes as inputs
88
and performs the operation
99
1010
```math
11-
\mathbf{u}_V = \box_{i \in V} \mathbf{x}_i
12-
````
11+
\mathbf{u}_V = \square_{i \in V} \mathbf{x}_i
12+
```
1313
where ``V`` is the set of nodes of the input graph and
14-
the type of aggregation represented by `\box` is selected by the `aggr` argument.
14+
the type of aggregation represented by ``\square`` is selected by the `aggr` argument.
1515
Commonly used aggregations are `mean`, `max`, and `+`.
1616
1717
```julia
18-
using GraphNeuralNetworks, LightGraphs
18+
using Flux, GraphNeuralNetworks, LightGraphs
1919
2020
pool = GlobalPool(mean)
2121
22-
g = GNNGraph(random_regular_graph(10, 4))
22+
g = GNNGraph(erdos_renyi(10, 4))
2323
X = rand(32, 10)
2424
pool(g, X) # => 32x1 matrix
25+
26+
27+
g = Flux.batch([GNNGraph(erdos_renyi(10, 4)) for _ in 1:5])
28+
X = rand(32, 50)
29+
pool(g, X) # => 32x5 matrix
2530
```
2631
"""
2732
struct GlobalPool{F} <: GNNLayer

0 commit comments

Comments
 (0)