Skip to content

Commit f0d8926

Browse files
Merge pull request #42 from CarloLucibello/cl/dev
size checks; add_self_loops option for GCNConv
2 parents c3a33a1 + c0cf5cd commit f0d8926

File tree

8 files changed

+102
-31
lines changed

8 files changed

+102
-31
lines changed

docs/src/gnngraph.md

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,17 @@ g = GNNGraph(erdos_renyi(10, 30), ndata = rand(Float32, 32, 10))
4545
g = GNNGraph(erdos_renyi(10, 30), ndata = (; x=rand(Float32, 32, 10), y=rand(Float32, 10)))
4646

4747

48-
# Attach an array with edge features
48+
# Attach an array with edge features.
49+
# Since `GNNGraph`s are directed, the number of edges
50+
# will be double that of the original LightGraphs' undirected graph.
51+
g = GNNGraph(erdos_renyi(10, 30), edata = rand(Float32, 60))
52+
@assert g.num_edges == 60
53+
54+
# If we pass only half of the edge features, they will be copied
55+
# on the reversed edges.
4956
g = GNNGraph(erdos_renyi(10, 30), edata = rand(Float32, 30))
5057

58+
5159
# Create a new graph from previous one, inheriting edge data
5260
# but replacing node data
5361
g′ = GNNGraph(g, ndata =(; z = ones(Float32, 16, 10)))
@@ -70,9 +78,9 @@ using Flux
7078
gall = Flux.batch([GNNGraph(erdos_renyi(10, 30), ndata=rand(Float32,3,10)) for _ in 1:160])
7179

7280
g23 = getgraph(gall, 2:3)
73-
@assert g23.num_graphs == 16
74-
@assert g23.num_nodes == 32
75-
@assert g23.num_edges == 60
81+
@assert g23.num_graphs == 2
82+
@assert g23.num_nodes == 20
83+
@assert g23.num_edges == 120 # 30 undirected edges x 2 directions x 2 graphs
7684

7785

7886
# DataLoader compatibility

docs/src/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ julia> for _ in 1:1000
3838
julia> gbatch = Flux.batch(all_graphs)
3939
GNNGraph:
4040
num_nodes = 10000
41-
num_edges = 20000
41+
num_edges = 40000
4242
num_graphs = 1000
4343
ndata:
4444
x => (16, 10000)

src/gnngraph.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -124,20 +124,20 @@ function GNNGraph(data;
124124
@assert dir ∈ [:in, :out]
125125

126126
if graph_type == :coo
127-
g, num_nodes, num_edges = to_coo(data; num_nodes, dir)
127+
graph, num_nodes, num_edges = to_coo(data; num_nodes, dir)
128128
elseif graph_type == :dense
129-
g, num_nodes, num_edges = to_dense(data; dir)
129+
graph, num_nodes, num_edges = to_dense(data; dir)
130130
elseif graph_type == :sparse
131-
g, num_nodes, num_edges = to_sparse(data; dir)
131+
graph, num_nodes, num_edges = to_sparse(data; dir)
132132
end
133133

134134
num_graphs = !isnothing(graph_indicator) ? maximum(graph_indicator) : 1
135135

136-
ndata = normalize_graphdata(ndata, :x)
137-
edata = normalize_graphdata(edata, :e)
138-
gdata = normalize_graphdata(gdata, :u)
136+
ndata = normalize_graphdata(ndata, default_name=:x, n=num_nodes)
137+
edata = normalize_graphdata(edata, default_name=:e, n=num_edges, duplicate_if_needed=true)
138+
gdata = normalize_graphdata(gdata, default_name=:u, n=num_graphs)
139139

140-
GNNGraph(g,
140+
GNNGraph(graph,
141141
num_nodes, num_edges, num_graphs,
142142
graph_indicator,
143143
ndata, edata, gdata)
@@ -154,16 +154,17 @@ function GNNGraph(g::AbstractGraph; kws...)
154154
t = LightGraphs.dst.(LightGraphs.edges(g))
155155
if !LightGraphs.is_directed(g)
156156
# add reverse edges since GNNGraph are directed
157-
s, t = [s; t], [t; s]
157+
s, t = [s; t], [t; s]
158158
end
159-
GNNGraph((s, t); num_nodes = LightGraphs.nv(g), kws...)
159+
GNNGraph((s, t); num_nodes=LightGraphs.nv(g), kws...)
160160
end
161161

162+
162163
function GNNGraph(g::GNNGraph; ndata=g.ndata, edata=g.edata, gdata=g.gdata)
163164

164-
ndata = normalize_graphdata(ndata, :x)
165-
edata = normalize_graphdata(edata, :e)
166-
gdata = normalize_graphdata(gdata, :u)
165+
ndata = normalize_graphdata(ndata, default_name=:x, n=g.num_nodes)
166+
edata = normalize_graphdata(edata, default_name=:e, n=g.num_edges, duplicate_if_needed=true)
167+
gdata = normalize_graphdata(gdata, default_name=:u, n=g.num_graphs)
167168

168169
GNNGraph(g.graph,
169170
g.num_nodes, g.num_edges, g.num_graphs,

src/layers/conv.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
@doc raw"""
2-
GCNConv(in => out, σ=identity; bias=true, init=glorot_uniform)
2+
GCNConv(in => out, σ=identity; bias=true, init=glorot_uniform, add_self_loops=true)
33
44
Graph convolutional layer from paper [Semi-supervised Classification with Graph Convolutional Networks](https://arxiv.org/abs/1609.02907).
55
66
Performs the operation
77
```math
8-
\mathbf{x}'_i = \sum_{j\in \{i\} \cup N(i)} \frac{1}{c_{ij}} W \mathbf{x}_j
8+
\mathbf{x}'_i = \sum_{j\in N(i)} \frac{1}{c_{ij}} W \mathbf{x}_j
99
```
10-
where ``c_{ij} = \sqrt{(1+|N(i)|)(1+|N(j)|)}``.
10+
where ``c_{ij} = \sqrt{|N(i)||N(j)|}``.
1111
1212
The input to the layer is a node feature array `X`
1313
of size `(num_features, num_nodes)`.
@@ -19,37 +19,42 @@ of size `(num_features, num_nodes)`.
1919
- `σ`: Activation function.
2020
- `bias`: Add learnable bias.
2121
- `init`: Weights' initializer.
22+
- `add_self_loops`: Add self loops to the graph before performing the convolution.
2223
"""
2324
struct GCNConv{A<:AbstractMatrix, B, F} <: GNNLayer
2425
weight::A
2526
bias::B
2627
σ::F
28+
add_self_loops::Bool
2729
end
2830

2931
@functor GCNConv
3032

3133
function GCNConv(ch::Pair{Int,Int}, σ=identity;
32-
init=glorot_uniform, bias::Bool=true)
34+
init=glorot_uniform, bias::Bool=true,
35+
add_self_loops=true)
3336
in, out = ch
3437
W = init(out, in)
3538
b = bias ? Flux.create_bias(W, true, out) : false
36-
GCNConv(W, b, σ)
39+
GCNConv(W, b, σ, add_self_loops)
3740
end
3841

3942
## Matrix operations are more performant,
40-
## but cannot compute the normalized laplacian of sparse cuda matrices yet,
43+
## but cannot compute the normalized adjacency of sparse cuda matrices yet,
4144
## therefore fall back to the message passing framework on gpu for the time being
4245

4346
function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix{T}) where T
44-
Ã = normalized_adjacency(g, T; dir=:out, add_self_loops=true)
47+
Ã = normalized_adjacency(g, T; dir=:out, l.add_self_loops)
4548
l.σ.(l.weight * x * Ã .+ l.bias)
4649
end
4750

4851
compute_message(l::GCNConv, xi, xj, eij) = xj
4952
update_node(l::GCNConv, m, x) = m
5053

5154
function (l::GCNConv)(g::GNNGraph, x::CuMatrix{T}) where T
52-
g = add_self_loops(g)
55+
if l.add_self_loops
56+
g = add_self_loops(g)
57+
end
5358
c = 1 ./ sqrt.(degree(g, T, dir=:in))
5459
x = x .* c'
5560
x, _ = propagate(l, g, +, x)

src/utils.jl

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,46 @@ function cat_features(x1::NamedTuple, x2::NamedTuple)
2424
end
2525

2626
# Turns generic type into named tuple
27-
normalize_graphdata(data::NamedTuple, s::Symbol) = data
28-
normalize_graphdata(data::Nothing, s::Symbol) = NamedTuple()
29-
normalize_graphdata(data, s::Symbol) = NamedTuple{(s,)}((data,))
27+
normalize_graphdata(data::Nothing; kws...) = NamedTuple()
28+
29+
normalize_graphdata(data; default_name::Symbol, kws...) =
30+
normalize_graphdata(NamedTuple{(default_name,)}((data,)); default_name, kws...)
31+
32+
function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_needed=false)
33+
# This had to work around two Zygote bugs with NamedTuples
34+
# https://github.com/FluxML/Zygote.jl/issues/1071
35+
# https://github.com/FluxML/Zygote.jl/issues/1072
36+
37+
if n == 1
38+
# If last array dimension is not 1, add a new dimension.
39+
# This is mostly useful to reshape global feature vectors
40+
# of size D to Dx1 matrices.
41+
function unsqz(v)
42+
if v isa AbstractArray && size(v)[end] != 1
43+
v = reshape(v, size(v)..., 1)
44+
end
45+
v
46+
end
47+
48+
data = NamedTuple{keys(data)}(unsqz.(values(data)))
49+
end
50+
51+
sz = map(x -> x isa AbstractArray ? size(x)[end] : 0, data)
52+
53+
if duplicate_if_needed
54+
# Used to copy edge features on reverse edges
55+
@assert all(s -> s == 0 || s == n || s == n÷2, sz)
56+
57+
function duplicate(v)
58+
if v isa AbstractArray && size(v)[end] == n÷2
59+
v = cat(v, v, dims=ndims(v))
60+
end
61+
v
62+
end
63+
64+
data = NamedTuple{keys(data)}(duplicate.(values(data)))
65+
else
66+
@assert all(s -> s == 0 || s == n, sz)
67+
end
68+
return data
69+
end

test/gnngraph.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,21 @@
182182
@test g.ndata.x2 ≈ 2X
183183
@test g.edata.e2 ≈ 2E
184184
@test g.gdata.u2 ≈ 2U
185+
186+
# Dimension checks
187+
@test_throws AssertionError GNNGraph(erdos_renyi(10, 30), edata=rand(29), graph_type=GRAPH_T)
188+
@test_throws AssertionError GNNGraph(erdos_renyi(10, 30), edata=rand(2, 29), graph_type=GRAPH_T)
189+
@test_throws AssertionError GNNGraph(erdos_renyi(10, 30), edata=(; x=rand(30), y=rand(29)), graph_type=GRAPH_T)
190+
191+
# Copy features on reverse edge
192+
e = rand(30)
193+
g = GNNGraph(erdos_renyi(10, 30), edata=e, graph_type=GRAPH_T)
194+
@test g.edata.e == [e; e]
195+
196+
197+
# Attach non array data
198+
g = GNNGraph(erdos_renyi(10, 30), edata="ciao", graph_type=GRAPH_T)
199+
@test g.edata.e == "ciao"
185200
end
186201

187202
@testset "LearnBase and DataLoader compat" begin

test/layers/conv.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,10 @@
3434
for g in test_graphs
3535
test_layer(l, g, rtol=1e-5)
3636
end
37-
end
3837

38+
l = GCNConv(in_channel => out_channel, add_self_loops=false)
39+
test_layer(l, g1, rtol=1e-5)
40+
end
3941

4042
@testset "ChebConv" begin
4143
k = 6

test/msgpass.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ import GraphNeuralNetworks: compute_message, update_node, update_edge, propagate
101101
@test all(adjacency_matrix(g_) .== adj)
102102
@test size(node_features(g_)) == (2*out_channel, num_V)
103103
@test size(edge_features(g_)) == (out_channel, num_E)
104-
@test size(graph_features(g_)) == (in_channel,)
104+
@test size(graph_features(g_)) == (in_channel, 1)
105105
end
106106

107107
@testset "message and update with weights" begin
@@ -124,7 +124,7 @@ import GraphNeuralNetworks: compute_message, update_node, update_edge, propagate
124124
@test adjacency_matrix(g_) == adj
125125
@test size(node_features(g_)) == (out_channel, num_V)
126126
@test edge_features(g_) === E
127-
@test graph_features(g_) === U
127+
@test vec(graph_features(g_)) ≈ U
128128
end
129129

130130
@testset "NamedTuples" begin

0 commit comments

Comments
 (0)