
Commit 15d230b

add subgraph
1 parent 11d7039 commit 15d230b

File tree

7 files changed: +169 -17 lines changed

Project.toml

Lines changed: 0 additions & 1 deletion
@@ -11,7 +11,6 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
 LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
-MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"

examples/Project.toml

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+[deps]
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
+LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
+MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
+# An example of graph classification
+
+using Flux
+using Flux: @functor, dropout, onecold, onehotbatch
+using Flux.Losses: logitbinarycrossentropy
+using GraphNeuralNetworks
+using MLDatasets: TUDataset
+using Statistics, Random
+using CUDA
+CUDA.allowscalar(false)
+
+function eval_loss_accuracy(model, g, X, y)
+    ŷ = model(g, X) |> vec
+    l = logitbinarycrossentropy(ŷ, y)
+    # accuracy: fraction where (2ŷ - 1) and (2y - 1) agree in sign
+    acc = mean((2 .* ŷ .- 1) .* (2 .* y .- 1) .> 0)
+    return (loss = round(l, digits=4), acc = round(acc*100, digits=2))
+end
+
+struct GNNData
+    g
+    X
+    y
+end
+
+# allow `g, X, y = getdataset(idxs)` destructuring below
+Base.iterate(d::GNNData, i=1) = i > 3 ? nothing : (getfield(d, i), i + 1)
+
+function getdataset(idxs)
+    data = TUDataset("MUTAG")[idxs]
+    @info "MUTAG: num_nodes: $(data.num_nodes) num_edges: $(data.num_edges) num_graphs: $(data.num_graphs)"
+    g = GNNGraph(data.source, data.target, num_nodes=data.num_nodes, graph_indicator=data.graph_indicator)
+    X = Array{Float32}(onehotbatch(data.node_labels, 0:6))
+    # E = Array{Float32}(onehotbatch(data.edge_labels, sort(unique(data.edge_labels))))
+    y = (1 .+ Array{Float32}(data.graph_labels)) ./ 2  # map graph labels from {-1,1} to {0,1}
+    @assert all(∈([0,1]), y) # binary classification
+    return GNNData(g, X, y)
+end
+
+# arguments for the `train` function
+Base.@kwdef mutable struct Args
+    η = 1f-3             # learning rate
+    epochs = 1000        # number of epochs
+    seed = 17            # set seed > 0 for reproducibility
+    use_cuda = false     # if true use cuda (if available)
+    nhidden = 128        # dimension of hidden features
+    infotime = 10        # report every `infotime` epochs
+end
+
+function train(; kws...)
+    args = Args(; kws...)
+    args.seed > 0 && Random.seed!(args.seed)
+
+    if args.use_cuda && CUDA.functional()
+        device = gpu
+        args.seed > 0 && CUDA.seed!(args.seed)
+        @info "Training on GPU"
+    else
+        device = cpu
+        @info "Training on CPU"
+    end
+
+    # LOAD DATA
+    permindx = randperm(188)  # MUTAG contains 188 graphs
+    ntrain = 150
+    gtrain, Xtrain, ytrain = getdataset(permindx[1:ntrain])
+    gtest, Xtest, ytest = getdataset(permindx[ntrain+1:end])
+
+    # DEFINE MODEL
+    nin = size(Xtrain, 1)
+    nhidden = args.nhidden
+
+    model = GNNChain(GCNConv(nin => nhidden, relu),
+                     Dropout(0.5),
+                     GCNConv(nhidden => nhidden, relu),
+                     GlobalPool(mean),
+                     Dense(nhidden, 1)) |> device
+
+    ps = Flux.params(model)
+    opt = ADAM(args.η)
+
+    # LOGGING FUNCTION
+    function report(epoch)
+        train = eval_loss_accuracy(model, gtrain, Xtrain, ytrain)
+        test = eval_loss_accuracy(model, gtest, Xtest, ytest)
+        println("Epoch: $epoch   Train: $(train)   Test: $(test)")
+    end
+
+    # TRAIN
+    report(0)
+    for epoch in 1:args.epochs
+        # for (g, X, y) in train_loader
+        gs = Flux.gradient(ps) do
+            ŷ = model(gtrain, Xtrain) |> vec
+            logitbinarycrossentropy(ŷ, ytrain)
+        end
+        Flux.Optimise.update!(opt, ps, gs)
+        # end
+
+        epoch % args.infotime == 0 && report(epoch)
+    end
+end
+
+# train()
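Presumably the script is run along these lines (a sketch, not part of the commit; the file name is omitted on this page, so the path below is a made-up stand-in):

```julia
# with the new examples environment active, e.g. `julia --project=examples`
include("examples/graph_classification.jl")   # hypothetical path

train()                               # Args defaults: 1000 epochs on CPU
train(use_cuda = true, epochs = 300)  # keywords are forwarded to Args
```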
File renamed without changes.

src/GraphNeuralNetworks.jl

Lines changed: 2 additions & 1 deletion
@@ -24,7 +24,8 @@ export
   edge_index,
   node_feature, edge_feature, global_feature,
   adjacency_list, normalized_laplacian, scaled_laplacian,
-  add_self_loops,
+  add_self_loops, remove_self_loops,
+  subgraph,

   # from LightGraphs
   adjacency_matrix,

src/gnngraph.jl

Lines changed: 54 additions & 14 deletions
@@ -11,7 +11,7 @@ const ADJMAT_T = AbstractMatrix
 const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T

 """
-    GNNGraph(data; [graph_type, nf, ef, gf, num_nodes, num_graphs, graph_indicator, dir])
+    GNNGraph(data; [graph_type, nf, ef, gf, num_nodes, graph_indicator, dir])
     GNNGraph(g::GNNGraph; [nf, ef, gf])

 A type representing a graph structure and storing also arrays
@@ -50,7 +50,6 @@ from the LightGraphs' graph library can be used on it.
 - `dir`. The assumed edge direction when given adjacency matrix or adjacency list input data `g`.
   Possible values are `:out` and `:in`. Default `:out`.
 - `num_nodes`. The number of nodes. If not specified, inferred from `g`. Default `nothing`.
-- `num_graphs`. The number of graphs. Larger than 1 in case of batched graphs. Default `1`.
 - `graph_indicator`. For batched graphs, a vector containing the graph assignment of each node. Default `nothing`.
 - `nf`: Node features. Either nothing, or an array whose last dimension has size num_nodes. Default `nothing`.
 - `ef`: Edge features. Either nothing, or an array whose last dimension has size num_edges. Default `nothing`.
@@ -123,17 +122,17 @@ function GNNGraph(data;

     @assert graph_type ∈ [:coo, :dense, :sparse] "Invalid graph_type $graph_type requested"
     @assert dir ∈ [:in, :out]
+
     if graph_type == :coo
         g, num_nodes, num_edges = to_coo(data; num_nodes, dir)
     elseif graph_type == :dense
         g, num_nodes, num_edges = to_dense(data; dir)
     elseif graph_type == :sparse
         g, num_nodes, num_edges = to_sparse(data; dir)
     end
-    if num_graphs > 1
-        @assert len(graph_indicator) = num_nodes "When batching multiple graphs `graph_indicator` should be filled with the nodes' memberships."
-    end
-
+
+    num_graphs = !isnothing(graph_indicator) ? maximum(graph_indicator) : 1
+
     ## Possible future implementation of feature maps.
     ## Currently this doesn't play well with zygote due to
     ## https://github.com/FluxML/Zygote.jl/issues/717
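After this hunk, `num_graphs` is no longer a constructor keyword: it is inferred as `maximum(graph_indicator)` whenever an indicator vector is passed. A quick sketch of the new behavior (toy edge lists, not from the commit):

```julia
# nodes 1-3 form graph 1, nodes 4-5 form graph 2
g = GNNGraph(([1, 2, 3, 4], [2, 3, 1, 5]),
             graph_indicator = [1, 1, 1, 2, 2])

g.num_graphs  # == 2, inferred from the indicator vector
```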
@@ -154,8 +153,8 @@ GNNGraph((s, t)::NTuple{2}; kws...) = GNNGraph((s, t, nothing); kws...)

 function GNNGraph(g::AbstractGraph; kws...)
     s = LightGraphs.src.(LightGraphs.edges(g))
-    t = LightGraphs.dst.(LightGraphs.edges(g))
-    GNNGraph((s, t); kws...)
+    t = LightGraphs.dst.(LightGraphs.edges(g))
+    GNNGraph((s, t); num_nodes = nv(g), kws...)
 end

 function GNNGraph(g::GNNGraph;
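The added `num_nodes = nv(g)` keyword matters for LightGraphs inputs with trailing isolated vertices, which cannot be recovered from the edge list alone. A sketch (not from the commit):

```julia
using LightGraphs, GraphNeuralNetworks

lg = SimpleGraph(5)    # 5 vertices, no edges yet
add_edge!(lg, 1, 2)    # vertices 3-5 remain isolated

g = GNNGraph(lg)
g.num_nodes  # == 5 via nv(lg); the edge list alone would suggest 2
```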
@@ -436,36 +435,77 @@ function _catgraphs(g1::GNNGraph{<:COO_T}, g2::GNNGraph{<:COO_T})
     )
 end

-# Cat public interfaces
+### Cat public interfaces #############

-```
+"""
     blockdiag(xs::GNNGraph...)

 Batch together multiple `GNNGraph`s into a single one
 containing the total number of nodes and edges of the original graphs.

 Equivalent to [`Flux.batch`](@ref).
-```
+"""
 function SparseArrays.blockdiag(g1::GNNGraph, gothers::GNNGraph...)
-    @assert length(gothers) >= 1
     g = g1
     for go in gothers
         g = _catgraphs(g, go)
     end
     return g
 end

-```
+"""
     batch(xs::Vector{<:GNNGraph})

 Batch together multiple `GNNGraph`s into a single one
 containing the total number of nodes and edges of the original graphs.

 Equivalent to [`SparseArrays.blockdiag`](@ref).
-```
+"""
 Flux.batch(xs::Vector{<:GNNGraph}) = blockdiag(xs...)
 #########################

+"""
+    subgraph(g::GNNGraph, i)
+
+Return the subgraph of `g` induced by those nodes `v`
+for which `g.graph_indicator[v] ∈ i`. In other words, it
+extracts the component graphs from a batched graph.
+
+It also returns a vector `nodes` mapping the new nodes to the old ones.
+The node `i` in the subgraph corresponds to the node `nodes[i]` in `g`.
+"""
+subgraph(g::GNNGraph, i::Int) = subgraph(g::GNNGraph{<:COO_T}, [i])
+
+function subgraph(g::GNNGraph{<:COO_T}, i::AbstractVector)
+    node_mask = g.graph_indicator .∈ Ref(i)
+
+    nodes = (1:g.num_nodes)[node_mask]
+    nodemap = Dict(v => vnew for (vnew, v) in enumerate(nodes))
+
+    graphmap = Dict(i => inew for (inew, i) in enumerate(i))
+    graph_indicator = [graphmap[i] for i in g.graph_indicator[node_mask]]
+
+    s, t, w = g.graph
+    edge_mask = s .∈ Ref(nodes)
+    s = [nodemap[i] for i in s[edge_mask]]
+    t = [nodemap[i] for i in t[edge_mask]]
+    w = isnothing(w) ? nothing : w[edge_mask]
+    nf = isnothing(g.nf) ? nothing : g.nf[:, node_mask]
+    ef = isnothing(g.ef) ? nothing : g.ef[:, edge_mask]
+    gf = isnothing(g.gf) ? nothing : g.gf[:, i]
+
+    num_nodes = length(graph_indicator)
+    num_edges = length(s)
+    num_graphs = length(i)
+
+    gnew = GNNGraph((s, t, w),
+                    num_nodes, num_edges, num_graphs,
+                    graph_indicator,
+                    nf, ef, gf)
+    return gnew, nodes
+end
+
 @non_differentiable normalized_laplacian(x...)
 @non_differentiable normalized_adjacency(x...)
 @non_differentiable scaled_laplacian(x...)
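To see how the new `subgraph` composes with the batching interface above, here is a minimal round-trip sketch (graph sizes and the 3-dimensional node features are made up for illustration):

```julia
using GraphNeuralNetworks, Flux, LightGraphs

g1 = GNNGraph(erdos_renyi(10, 30), nf = rand(Float32, 3, 10))
g2 = GNNGraph(erdos_renyi(15, 40), nf = rand(Float32, 3, 15))

g = Flux.batch([g1, g2])   # batched graph with g.num_graphs == 2

# extract the second component; `nodes` maps its node ids back to ids in `g`
gsub, nodes = subgraph(g, 2)
gsub.num_nodes  # == 15
```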

src/layers/pool.jl

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ X = rand(32, 10)
 pool(g, X) # => 32x1 matrix
 ```
 """
-struct GlobalPool{F}
+struct GlobalPool{F} <: GNNLayer
     aggr::F
 end
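Subtyping `GNNLayer` is what lets `GlobalPool` slot into a `GNNChain` between graph convolutions, as the new example script does. A small sketch (dimensions are illustrative, not from the commit):

```julia
using GraphNeuralNetworks, Flux, LightGraphs, Statistics

g = GNNGraph(erdos_renyi(10, 30))
X = rand(Float32, 32, 10)   # 32 features for each of the 10 nodes

model = GNNChain(GCNConv(32 => 64, relu),
                 GlobalPool(mean),   # graph-level readout: one 64-vector per graph
                 Dense(64, 1))

model(g, X)  # 1×1 output for this single graph
```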
