Skip to content

Commit b9e8571

Browse files
duplicate edge features if needed
1 parent c3a33a1 commit b9e8571

File tree

6 files changed

+77
-26
lines changed

6 files changed

+77
-26
lines changed

docs/src/gnngraph.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,17 @@ g = GNNGraph(erdos_renyi(10, 30), ndata = rand(Float32, 32, 10))
4545
g = GNNGraph(erdos_renyi(10, 30), ndata = (; x=rand(Float32, 32, 10), y=rand(Float32, 10)))
4646

4747

48-
# Attach an array with edge features
48+
# Attach an array with edge features.
49+
# Since `GNNGraph`s are directed, the number of edges
50+
# will be double that of the original LightGraphs' undirected graph.
51+
g = GNNGraph(erdos_renyi(10, 30), edata = rand(Float32, 60))
52+
@assert g.num_edges == 60
53+
54+
# If we pass only half of the edge features, they will be copied
55+
# on the reversed edges.
4956
g = GNNGraph(erdos_renyi(10, 30), edata = rand(Float32, 30))
5057

58+
5159
# Create a new graph from previous one, inheriting edge data
5260
# but replacing node data
5361
g′ = GNNGraph(g, ndata =(; z = ones(Float32, 16, 10)))

docs/src/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ julia> for _ in 1:1000
3838
julia> gbatch = Flux.batch(all_graphs)
3939
GNNGraph:
4040
num_nodes = 10000
41-
num_edges = 20000
41+
num_edges = 40000
4242
num_graphs = 1000
4343
ndata:
4444
x => (16, 10000)

src/gnngraph.jl

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -124,20 +124,20 @@ function GNNGraph(data;
124124
@assert dir [:in, :out]
125125

126126
if graph_type == :coo
127-
g, num_nodes, num_edges = to_coo(data; num_nodes, dir)
127+
graph, num_nodes, num_edges = to_coo(data; num_nodes, dir)
128128
elseif graph_type == :dense
129-
g, num_nodes, num_edges = to_dense(data; dir)
129+
graph, num_nodes, num_edges = to_dense(data; dir)
130130
elseif graph_type == :sparse
131-
g, num_nodes, num_edges = to_sparse(data; dir)
131+
graph, num_nodes, num_edges = to_sparse(data; dir)
132132
end
133133

134134
num_graphs = !isnothing(graph_indicator) ? maximum(graph_indicator) : 1
135135

136-
ndata = normalize_graphdata(ndata, :x)
137-
edata = normalize_graphdata(edata, :e)
138-
gdata = normalize_graphdata(gdata, :u)
136+
ndata = normalize_graphdata(ndata, default_name=:x, n=num_nodes)
137+
edata = normalize_graphdata(edata, default_name=:e, n=num_edges, duplicate_if_needed=true)
138+
gdata = normalize_graphdata(gdata, default_name=:u, n=num_graphs)
139139

140-
GNNGraph(g,
140+
GNNGraph(graph,
141141
num_nodes, num_edges, num_graphs,
142142
graph_indicator,
143143
ndata, edata, gdata)
@@ -149,21 +149,22 @@ GNNGraph((s, t)::NTuple{2}; kws...) = GNNGraph((s, t, nothing); kws...)
149149

150150
# GNNGraph(g::AbstractGraph; kws...) = GNNGraph(adjacency_matrix(g, dir=:out); kws...)
151151

152-
function GNNGraph(g::AbstractGraph; kws...)
152+
function GNNGraph(g::AbstractGraph; edata=(;), kws...)
153153
s = LightGraphs.src.(LightGraphs.edges(g))
154154
t = LightGraphs.dst.(LightGraphs.edges(g))
155155
if !LightGraphs.is_directed(g)
156156
# add reverse edges since GNNGraph are directed
157-
s, t = [s; t], [t; s]
157+
s, t = [s; t], [t; s]
158158
end
159-
GNNGraph((s, t); num_nodes = LightGraphs.nv(g), kws...)
159+
GNNGraph((s, t); edata, num_nodes=LightGraphs.nv(g), kws...)
160160
end
161161

162+
162163
function GNNGraph(g::GNNGraph; ndata=g.ndata, edata=g.edata, gdata=g.gdata)
163164

164-
ndata = normalize_graphdata(ndata, :x)
165-
edata = normalize_graphdata(edata, :e)
166-
gdata = normalize_graphdata(gdata, :u)
165+
ndata = normalize_graphdata(ndata, default_name=:x, n=g.num_nodes)
166+
edata = normalize_graphdata(edata, default_name=:e, n=g.num_edges, duplicate_if_needed=true)
167+
gdata = normalize_graphdata(gdata, default_name=:u, n=g.num_graphs)
167168

168169
GNNGraph(g.graph,
169170
g.num_nodes, g.num_edges, g.num_graphs,

src/layers/conv.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
@doc raw"""
2-
GCNConv(in => out, σ=identity; bias=true, init=glorot_uniform)
2+
GCNConv(in => out, σ=identity; bias=true, init=glorot_uniform, add_self_loops=true)
33
44
Graph convolutional layer from paper [Semi-supervised Classification with Graph Convolutional Networks](https://arxiv.org/abs/1609.02907).
55
66
Performs the operation
77
```math
8-
\mathbf{x}'_i = \sum_{j\in \{i\} \cup N(i)} \frac{1}{c_{ij}} W \mathbf{x}_j
8+
\mathbf{x}'_i = \sum_{j\in N(i)} \frac{1}{c_{ij}} W \mathbf{x}_j
99
```
10-
where ``c_{ij} = \sqrt{(1+|N(i)|)(1+|N(j)|)}``.
10+
where ``c_{ij} = \sqrt{|N(i)||N(j)|}``.
1111
1212
The input to the layer is a node feature array `X`
1313
of size `(num_features, num_nodes)`.
@@ -19,37 +19,42 @@ of size `(num_features, num_nodes)`.
1919
- `σ`: Activation function.
2020
- `bias`: Add learnable bias.
2121
- `init`: Weights' initializer.
22+
- `add_self_loops`: Add self loops to the graph before performing the convolution.
2223
"""
2324
struct GCNConv{A<:AbstractMatrix, B, F} <: GNNLayer
2425
weight::A
2526
bias::B
2627
σ::F
28+
add_self_loops::Bool
2729
end
2830

2931
@functor GCNConv
3032

3133
function GCNConv(ch::Pair{Int,Int}, σ=identity;
32-
init=glorot_uniform, bias::Bool=true)
34+
init=glorot_uniform, bias::Bool=true,
35+
add_self_loops=true)
3336
in, out = ch
3437
W = init(out, in)
3538
b = bias ? Flux.create_bias(W, true, out) : false
36-
GCNConv(W, b, σ)
39+
GCNConv(W, b, σ, add_self_loops)
3740
end
3841

3942
## Matrix operations are more performant,
40-
## but cannot compute the normalized laplacian of sparse cuda matrices yet,
43+
## but cannot compute the normalized adjacency of sparse cuda matrices yet,
4144
## therefore fallback to message passing framework on gpu for the time being
4245

4346
function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix{T}) where T
44-
= normalized_adjacency(g, T; dir=:out, add_self_loops=true)
47+
= normalized_adjacency(g, T; dir=:out, l.add_self_loops)
4548
l.σ.(l.weight * x *.+ l.bias)
4649
end
4750

4851
compute_message(l::GCNConv, xi, xj, eij) = xj
4952
update_node(l::GCNConv, m, x) = m
5053

5154
function (l::GCNConv)(g::GNNGraph, x::CuMatrix{T}) where T
52-
g = add_self_loops(g)
55+
if l.add_self_loops
56+
g = add_self_loops(g)
57+
end
5358
c = 1 ./ sqrt.(degree(g, T, dir=:in))
5459
x = x .* c'
5560
x, _ = propagate(l, g, +, x)

src/utils.jl

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,28 @@ function cat_features(x1::NamedTuple, x2::NamedTuple)
2424
end
2525

2626
# Turns generic type into named tuple
27-
normalize_graphdata(data::NamedTuple, s::Symbol) = data
28-
normalize_graphdata(data::Nothing, s::Symbol) = NamedTuple()
29-
normalize_graphdata(data, s::Symbol) = NamedTuple{(s,)}((data,))
27+
normalize_graphdata(data::Nothing; kws...) = NamedTuple()
28+
29+
normalize_graphdata(data; default_name::Symbol, kws...) =
30+
normalize_graphdata(NamedTuple{(default_name,)}((data,)); default_name, kws...)
31+
32+
function normalize_graphdata(data::NamedTuple; default_name=:z, n, duplicate_if_needed=false)
33+
sz = map(x -> x isa AbstractArray ? size(x)[end] : 0, data)
34+
35+
if duplicate_if_needed # used to copy edge features on reverse edges
36+
@assert all(s -> s == 0 || s == n || s == n÷2, sz)
37+
38+
function replace(k, v)
39+
if v isa AbstractArray && size(v)[end] == n÷2
40+
v = cat(v, v, dims=ndims(v))
41+
end
42+
k => v
43+
end
44+
45+
data = NamedTuple(replace(k,v) for (k,v) in pairs(data))
46+
else
47+
@assert all(s -> s == 0 || s == n, sz)
48+
end
49+
return data
50+
end
51+

test/gnngraph.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,21 @@
182182
@test g.ndata.x2 2X
183183
@test g.edata.e2 2E
184184
@test g.gdata.u2 2U
185+
186+
# Dimension checks
187+
@test_throws AssertionError GNNGraph(erdos_renyi(10, 30), edata=rand(29), graph_type=GRAPH_T)
188+
@test_throws AssertionError GNNGraph(erdos_renyi(10, 30), edata=rand(2, 29), graph_type=GRAPH_T)
189+
@test_throws AssertionError GNNGraph(erdos_renyi(10, 30), edata=(; x=rand(30), y=rand(29)), graph_type=GRAPH_T)
190+
191+
# Copy features on reverse edge
192+
e = rand(30)
193+
g = GNNGraph(erdos_renyi(10, 30), edata=e, graph_type=GRAPH_T)
194+
@test g.edata.e == [e; e]
195+
196+
197+
# Attach non array data
198+
g = GNNGraph(erdos_renyi(10, 30), edata="ciao", graph_type=GRAPH_T)
199+
@test g.edata.e == "ciao"
185200
end
186201

187202
@testset "LearnBase and DataLoader compat" begin

0 commit comments

Comments
 (0)