Skip to content

Commit cb515a4

Browse files
Merge pull request #17 from CarloLucibello/cl/graphs
add support for batched graphs
2 parents b4a8675 + 3043cce commit cb515a4

File tree

6 files changed

+108
-55
lines changed

6 files changed

+108
-55
lines changed

src/GraphNeuralNetworks.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ export
4747

4848
# layers/pool
4949
GlobalPool,
50-
LocalPool,
5150
TopKPool,
5251
topk_index
5352

src/gnngraph.jl

Lines changed: 59 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ const ADJMAT_T = AbstractMatrix
1111
const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T
1212

1313
"""
14-
GNNGraph(data; [graph_type, dir, num_nodes, nf, ef, gf])
14+
GNNGraph(data; [graph_type, nf, ef, gf, num_nodes, num_graphs, graph_indicator, dir])
1515
GNNGraph(g::GNNGraph; [nf, ef, gf])
1616
1717
A type representing a graph structure and storing also arrays
@@ -43,11 +43,13 @@ from the LightGraphs' graph library can be used on it.
4343
- `:dense`. A dense adjacency matrix representation.
4444
Default `:coo`.
4545
- `dir`. The assumed edge direction when given adjacency matrix or adjacency list input data `g`.
46-
Possible values are `:out` and `:in`. Defaul `:out`.
47-
- `num_nodes`. The number of nodes. If not specified, inferred from `g`. Default nothing.
48-
- `nf`: Node features. Either nothing, or an array whose last dimension has size num_nodes. Default nothing.
49-
- `ef`: Edge features. Either nothing, or an array whose last dimension has size num_edges. Default nothing.
50-
- `gf`: Global features. Default nothing.
46+
Possible values are `:out` and `:in`. Default `:out`.
47+
- `num_nodes`. The number of nodes. If not specified, inferred from `g`. Default `nothing`.
48+
- `num_graphs`. The number of graphs. Larger than 1 in case of batched graphs. Default `1`.
49+
- `graph_indicator`. For batched graphs, a vector containing the graph assignment of each node. Default `nothing`.
50+
- `nf`: Node features. Either nothing, or an array whose last dimension has size num_nodes. Default `nothing`.
51+
- `ef`: Edge features. Either nothing, or an array whose last dimension has size num_edges. Default `nothing`.
52+
- `gf`: Global features. Default `nothing`.
5153
5254
# Usage.
5355
@@ -87,6 +89,8 @@ struct GNNGraph{T<:Union{COO_T,ADJMAT_T}}
8789
graph::T
8890
num_nodes::Int
8991
num_edges::Int
92+
num_graphs::Int
93+
graph_indicator
9094
nf
9195
ef
9296
gf
@@ -99,7 +103,9 @@ end
99103
@functor GNNGraph
100104

101105
function GNNGraph(data;
102-
num_nodes = nothing,
106+
num_nodes = nothing,
107+
num_graphs = 1,
108+
graph_indicator = nothing,
103109
graph_type = :coo,
104110
dir = :out,
105111
nf = nothing,
@@ -119,6 +125,9 @@ function GNNGraph(data;
119125
elseif graph_type == :sparse
120126
g, num_nodes, num_edges = to_sparse(data; dir)
121127
end
128+
if num_graphs > 1
129+
@assert length(graph_indicator) == num_nodes "When batching multiple graphs `graph_indicator` should be filled with the nodes' memberships."
130+
end
122131

123132
## Possible future implementation of feature maps.
124133
## Currently this doesn't play well with zygote due to
@@ -127,8 +136,9 @@ function GNNGraph(data;
127136
# edata["e"] = ef
128137
# gdata["g"] = gf
129138

130-
131-
GNNGraph(g, num_nodes, num_edges, nf, ef, gf)
139+
GNNGraph(g, num_nodes, num_edges,
140+
num_graphs, graph_indicator,
141+
nf, ef, gf)
132142
end
133143

134144
# COO convenience constructors
@@ -147,7 +157,7 @@ function GNNGraph(g::GNNGraph;
147157
nf=node_feature(g), ef=edge_feature(g), gf=global_feature(g))
148158
# ndata=copy(g.ndata), edata=copy(g.edata), gdata=copy(g.gdata), # copy keeps the refs to old data
149159

150-
GNNGraph(g.graph, g.num_nodes, g.num_edges, nf, ef, gf) # ndata, edata, gdata,
160+
GNNGraph(g.graph, g.num_nodes, g.num_edges, g.num_graphs, g.graph_indicator, nf, ef, gf) # ndata, edata, gdata,
151161
end
152162

153163

@@ -370,6 +380,7 @@ function add_self_loops(g::GNNGraph{<:COO_T})
370380
t = [t; nodes]
371381

372382
GNNGraph((s, t, nothing), g.num_nodes, length(s),
383+
g.num_graphs, g.graph_indicator,
373384
node_feature(g), edge_feature(g), global_feature(g))
374385
end
375386

@@ -379,6 +390,7 @@ function add_self_loops(g::GNNGraph{<:ADJMAT_T}; add_to_existing=true)
379390
A += I
380391
num_edges = g.num_edges + g.num_nodes
381392
GNNGraph(A, g.num_nodes, num_edges,
393+
g.num_graphs, g.graph_indicator,
382394
node_feature(g), edge_feature(g), global_feature(g))
383395
end
384396

@@ -392,10 +404,46 @@ function remove_self_loops(g::GNNGraph{<:COO_T})
392404
s = s[mask_old_loops]
393405
t = t[mask_old_loops]
394406

395-
GNNGraph((s, t, nothing), g.num_nodes, length(s),
407+
GNNGraph((s, t, nothing), g.num_nodes, length(s),
408+
g.num_graphs, g.graph_indicator,
396409
node_feature(g), edge_feature(g), global_feature(g))
397410
end
398411

412+
function _catgraphs(g1::GNNGraph{<:COO_T}, g2::GNNGraph{<:COO_T})
413+
s1, t1 = edge_index(g1)
414+
s2, t2 = edge_index(g2)
415+
nv1, nv2 = g1.num_nodes, g2.num_nodes
416+
s = vcat(s1, nv1 .+ s2)
417+
t = vcat(t1, nv1 .+ t2)
418+
w = cat_features(edge_weight(g1), edge_weight(g2))
419+
420+
ind1 = isnothing(g1.graph_indicator) ? fill!(similar(s1, Int, nv1), 1) : g1.graph_indicator
421+
ind2 = isnothing(g2.graph_indicator) ? fill!(similar(s2, Int, nv2), 1) : g2.graph_indicator
422+
graph_indicator = vcat(ind1, g1.num_graphs .+ ind2)
423+
424+
GNNGraph(
425+
(s, t, w),
426+
nv1 + nv2, g1.num_edges + g2.num_edges,
427+
g1.num_graphs + g2.num_graphs, graph_indicator,
428+
cat_features(node_feature(g1), node_feature(g2)),
429+
cat_features(edge_feature(g1), edge_feature(g2)),
430+
cat_features(global_feature(g1), global_feature(g2)),
431+
)
432+
end
433+
434+
# Cat public interfaces
435+
function SparseArrays.blockdiag(g1::GNNGraph, gothers::GNNGraph...)
436+
@assert length(gothers) >= 1
437+
g = g1
438+
for go in gothers
439+
g = _catgraphs(g, go)
440+
end
441+
return g
442+
end
443+
444+
Flux.batch(xs::Vector{<:GNNGraph}) = blockdiag(xs...)
445+
#########################
446+
399447
@non_differentiable normalized_laplacian(x...)
400448
@non_differentiable normalized_adjacency(x...)
401449
@non_differentiable scaled_laplacian(x...)

src/layers/pool.jl

Lines changed: 29 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,42 @@
11
using DataStructures: nlargest
22

3-
"""
4-
GlobalPool(aggr, dim...)
5-
6-
Global pooling layer.
7-
8-
It pools all features with `aggr` operation.
9-
10-
# Arguments
11-
12-
- `aggr`: An aggregate function applied to pool all features.
13-
"""
14-
struct GlobalPool{A}
15-
aggr
16-
cluster::A
17-
function GlobalPool(aggr, dim...)
18-
cluster = ones(Int64, dim)
19-
new{typeof(cluster)}(aggr, cluster)
20-
end
21-
end
3+
@doc raw"""
4+
GlobalPool(aggr)
225
23-
(l::GlobalPool)(X::AbstractArray) = NNlib.scatter(l.aggr, X, l.cluster)
6+
Global pooling layer for graph neural networks.
7+
Takes a graph and feature nodes as inputs
8+
and performs the operation
249
25-
"""
26-
LocalPool(aggr, cluster)
10+
```math
11+
\mathbf{u}_V = \square_{i \in V} \mathbf{x}_i
12+
```
13+
where ``V`` is the set of nodes of the input graph and
14+
the type of aggregation represented by `\square` is selected by the `aggr` argument.
15+
Commonly used aggregations are `mean`, `max`, and `+`.
2716
28-
Local pooling layer.
17+
```julia
18+
using GraphNeuralNetworks, LightGraphs
2919
30-
It pools features with `aggr` operation accroding to `cluster`. It is implemented with `scatter` operation.
20+
pool = GlobalPool(mean)
3121
32-
# Arguments
33-
34-
- `aggr`: An aggregate function applied to pool all features.
35-
- `cluster`: An index structure which indicates what features to aggregate with.
22+
g = GNNGraph(random_regular_graph(10, 4))
23+
X = rand(32, 10)
24+
pool(g, X) # => 32x1 matrix
25+
```
3626
"""
37-
struct LocalPool{A<:AbstractArray}
38-
aggr
39-
cluster::A
27+
struct GlobalPool{F}
28+
aggr::F
4029
end
4130

42-
(l::LocalPool)(X::AbstractArray) = NNlib.scatter(l.aggr, X, l.cluster)
31+
function (l::GlobalPool)(g::GNNGraph, X::AbstractArray)
32+
if isnothing(g.graph_indicator)
33+
# assume only one graph
34+
indexes = fill!(similar(X, Int, g.num_nodes), 1)
35+
else
36+
indexes = g.graph_indicator
37+
end
38+
return NNlib.scatter(l.aggr, X, indexes)
39+
end
4340

4441
"""
4542
TopKPool(adj, k, in_channel)

src/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,6 @@ function sort_edge_index(u, v)
99
p = sortperm(uv) # isless lexicographically defined for tuples
1010
return u[p], v[p]
1111
end
12+
13+
cat_features(x1::Nothing, x2::Nothing) = nothing
14+
cat_features(x1::AbstractArray, x2::AbstractArray) = cat(x1, x2, dims=ndims(x1))

test/gnngraph.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,4 +101,16 @@
101101
@test adjacency_matrix(fg2) == A2
102102
@test fg2.num_edges == sum(A2)
103103
end
104+
105+
@testset "batch" begin
106+
g1 = GNNGraph(random_regular_graph(10,2), nf=rand(16,10))
107+
g2 = GNNGraph(random_regular_graph(4,2), nf=rand(16,4))
108+
g3 = GNNGraph(random_regular_graph(7,2), nf=rand(16,7))
109+
110+
g12 = Flux.batch([g1, g2])
111+
g12b = blockdiag(g1, g2)
112+
113+
g123 = Flux.batch([g1, g2, g3])
114+
@test g123.graph_indicator == [fill(1, 10); fill(2, 4); fill(3, 7)]
115+
end
104116
end

test/layers/pool.jl

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,10 @@
1-
cluster = [1 1 1 1; 2 2 3 3; 4 4 5 5]
2-
X = Array(reshape(1:24, 2, 3, 4))
3-
41
@testset "pool" begin
52
@testset "GlobalPool" begin
6-
glb_cltr = [1 1 1 1; 1 1 1 1; 1 1 1 1]
7-
p = GlobalPool(+, 3, 4)
8-
@test p(X) == NNlib.scatter(+, X, glb_cltr)
9-
end
10-
11-
@testset "LocalPool" begin
12-
p = LocalPool(+, cluster)
13-
@test p(X) == NNlib.scatter(+, X, cluster)
3+
n = 10
4+
X = rand(16, n)
5+
g = GNNGraph(random_regular_graph(n, 4))
6+
p = GlobalPool(+)
7+
@test p(g, X) ≈ NNlib.scatter(+, X, ones(Int, n))
148
end
159

1610
@testset "TopKPool" begin

0 commit comments

Comments
 (0)