Skip to content

Commit ce10af3

Browse files
support self loops
1 parent 9b8bbf1 commit ce10af3

File tree

7 files changed

+102
-26
lines changed

7 files changed

+102
-26
lines changed

docs/src/api/messagepassing.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,5 @@ propagate
2424
copy_xi
2525
copy_xj
2626
xi_dot_xj
27+
e_mul_xj
2728
```

docs/src/gnngraph.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,25 @@ g.ndata.z
114114
g.edata.e
115115
```
116116

117+
## Edge weights
118+
119+
It is common to denote scalar edge features as edge weights. The `GNNGraph` has specific support
120+
for edge weights: they can be stored as part of internal representions of the graph (COO or adjacency matrix). Some graph convolutional layers, most notably the [`GCNConv`](@ref), can use the edge weights to perform weighted sums over the nodes' neighborhoods.
121+
122+
```julia
123+
julia> source = [1, 1, 2, 2, 3, 3];
124+
125+
julia> target = [2, 3, 1, 3, 1, 2];
126+
127+
julia> weight = [1.0, 0.5, 2.1, 2.3, 4, 4.1];
128+
129+
julia> g = GNNGraph(source, target, weight)
130+
GNNGraph:
131+
num_nodes = 3
132+
num_edges = 6
133+
134+
```
135+
117136
## Batches and Subgraphs
118137

119138
Multiple `GNNGraph`s can be batched togheter into a single graph

src/GNNGraphs/query.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,9 @@ function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType=nodetype(g
113113
end
114114

115115
function _get_edge_weight(g, edge_weight)
116-
if edge_weight === true
116+
if edge_weight === true || edge_weight === nothing
117117
ew = get_edge_weight(g)
118-
elseif (edge_weight === false) || (edge_weight === nothing)
118+
elseif edge_weight === false
119119
ew = nothing
120120
elseif edge_weight isa AbstractVector
121121
ew = edge_weight
@@ -147,7 +147,7 @@ function Graphs.degree(g::GNNGraph{<:COO_T}, T=nothing; dir=:out, edge_weight=tr
147147
s, t = edge_index(g)
148148

149149
edge_weight = _get_edge_weight(g, edge_weight)
150-
edge_weight = isnothing(edge_weight) ? eltype(s)(1) : edge_weight
150+
edge_weight = edge_weight === nothing ? eltype(s)(1) : edge_weight
151151

152152
T = isnothing(T) ? eltype(edge_weight) : T
153153
degs = fill!(similar(s, T, g.num_nodes), 0)
@@ -161,19 +161,20 @@ function Graphs.degree(g::GNNGraph{<:COO_T}, T=nothing; dir=:out, edge_weight=tr
161161
end
162162

163163
function Graphs.degree(g::GNNGraph{<:ADJMAT_T}, T=nothing; dir=:out, edge_weight=true)
164+
# edge_weight=true or edge_weight=nothing act the same here
164165
@assert !(edge_weight isa AbstractArray) "passing the edge weights is not support by adjacency matrix representations"
165166
@assert dir (:in, :out, :both)
166167
if T === nothing
167168
Nt = nodetype(g)
168-
if ((edge_weight === false) || (edge_weight === nothing)) && !(Nt <: Integer)
169+
if edge_weight === false && !(Nt <: Integer)
169170
T = Nt == Float32 ? Int32 :
170171
Nt == Float16 ? Int16 : Int
171172
else
172173
T = Nt
173174
end
174175
end
175176
A = adjacency_matrix(g)
176-
if (edge_weight === false) || (edge_weight === nothing)
177+
if edge_weight === false
177178
A = map(>(0), A)
178179
end
179180
A = eltype(A) != T ? T.(A) : A

src/GNNGraphs/transform.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,23 @@
55
Return a graph with the same features as `g`
66
but also adding edges connecting the nodes to themselves.
77
8-
Nodes with already existing
9-
self-loops will obtain a second self-loop.
8+
Nodes with already existing self-loops will obtain a second self-loop.
9+
10+
If the graphs has edge weights, the new edges will have weight 1.
1011
"""
1112
function add_self_loops(g::GNNGraph{<:COO_T})
1213
s, t = edge_index(g)
1314
@assert g.edata === (;)
14-
@assert get_edge_weight(g) === nothing
15+
ew = get_edge_weight(g)
1516
n = g.num_nodes
1617
nodes = convert(typeof(s), [1:n;])
1718
s = [s; nodes]
1819
t = [t; nodes]
20+
if ew !== nothing
21+
ew = [ew; fill!(similar(ew, n), 1)]
22+
end
1923

20-
GNNGraph((s, t, nothing),
24+
GNNGraph((s, t, ew),
2125
g.num_nodes, length(s), g.num_graphs,
2226
g.graph_indicator,
2327
g.ndata, g.edata, g.gdata)
@@ -340,7 +344,7 @@ function negative_sample(g::GNNGraph;
340344
@assert g.num_graphs == 1
341345
# Consider self-loops as positive edges
342346
# Construct new graph dropping features
343-
g = add_self_loops(GNNGraph(edge_index(g)))
347+
g = add_self_loops(GNNGraph(edge_index(g), num_nodes=g.num_nodes))
344348

345349
s, t = edge_index(g)
346350
n = g.num_nodes

src/layers/conv.jl

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@doc raw"""
2-
GCNConv(in => out, σ=identity; bias=true, init=glorot_uniform, add_self_loops=true, edge_weight=true)
2+
GCNConv(in => out, σ=identity; [bias, init, add_self_loops, use_edge_weight])
33
44
Graph convolutional layer from paper [Semi-supervised Classification with Graph Convolutional Networks](https://arxiv.org/abs/1609.02907).
55
@@ -9,7 +9,7 @@ Performs the operation
99
```
1010
where ``a_{ij} = 1 / \sqrt{|N(i)||N(j)|}`` is a normalization factor computed from the node degrees.
1111
12-
If the input graph has weighted edges and `edge_weight=true`, than ``c_{ij}`` will be computed as
12+
If the input graph has weighted edges and `use_edge_weight=true`, than ``a_{ij}`` will be computed as
1313
```math
1414
a_{ij} = \frac{e_{j\to i}}{\sqrt{\sum_{j \in N(i)} e_{j\to i}} \sqrt{\sum_{i \in N(j)} e_{i\to j}}}
1515
```
@@ -25,15 +25,40 @@ and optionally an edge weight vector.
2525
- `bias`: Add learnable bias. Default `true`.
2626
- `init`: Weights' initializer. Default `glorot_uniform`.
2727
- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `false`.
28-
- `edge_weight`. If `true`, consider the edge weights in the input graph (if available).
29-
Not compatible with `add_self_loops=true` at the moment. Default `true`.
28+
- `use_edge_weight`. If `true`, consider the edge weights in the input graph (if available).
29+
If `add_self_loops=true` the new weights will be set to 1. Default `false`.
30+
31+
# Examples
32+
33+
```julia
34+
# create data
35+
s = [1,1,2,3]
36+
t = [2,3,1,1]
37+
g = GNNGraph(s, t)
38+
x = randn(3, g.num_nodes)
39+
40+
# create layer
41+
l = GCNConv(3 => 5)
42+
43+
# forward pass
44+
y = l(g, x) # size: 5 × num_nodes
45+
46+
# convolution with edge weights
47+
w = [1.1, 0.1, 2.3, 0.5]
48+
y = l(g, x, w)
49+
50+
# Edge weights can also be embedded in the graph.
51+
g = GNNGraph(s, t, w)
52+
l = GCNConv(3 => 5, use_edge_weight=true)
53+
y = l(g, x) # same as l(g, x, w)
54+
```
3055
"""
3156
struct GCNConv{A<:AbstractMatrix, B, F} <: GNNLayer
3257
weight::A
3358
bias::B
3459
σ::F
3560
add_self_loops::Bool
36-
edge_weight::Bool
61+
use_edge_weight::Bool
3762
end
3863

3964
@functor GCNConv
@@ -42,46 +67,56 @@ function GCNConv(ch::Pair{Int,Int}, σ=identity;
4267
init=glorot_uniform,
4368
bias::Bool=true,
4469
add_self_loops=true,
45-
edge_weight=false)
70+
use_edge_weight=false)
4671
in, out = ch
4772
W = init(out, in)
4873
b = bias ? Flux.create_bias(W, true, out) : false
49-
GCNConv(W, b, σ, add_self_loops, edge_weight)
74+
GCNConv(W, b, σ, add_self_loops, use_edge_weight)
5075
end
5176

52-
function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix)
53-
# Extract edge_weight from g if available and l.edge_weight == false,
77+
function (l::GCNConv)(g::GNNGraph{<:COO_T}, x::AbstractMatrix)
78+
# Extract edge_weight from g if available and l.edge_weight == true,
5479
# otherwise return nothing.
55-
edge_weight = GNNGraphs._get_edge_weight(g, l.edge_weight) # vector or nothing
80+
edge_weight = GNNGraphs._get_edge_weight(g, l.use_edge_weight) # vector or nothing
81+
return l(g, x, edge_weight)
82+
end
83+
84+
function (l::GCNConv)(g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix)
85+
edge_weight = nothing
5686
return l(g, x, edge_weight)
5787
end
5888

5989
function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix{T}, edge_weight::EW) where
6090
{T, EW<:Union{Nothing,AbstractVector}}
6191

6292
if l.add_self_loops
63-
@assert edge_weight === nothing
6493
g = add_self_loops(g)
94+
if edge_weight !== nothing
95+
edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)]
96+
@assert length(edge_weight) == g.num_edges
97+
end
6598
end
6699
Dout, Din = size(l.weight)
67100
if Dout < Din
101+
# multiply before convolution if it is more convenient, otherwise multiply after
68102
x = l.weight * x
69103
end
70-
# @assert all(>(0), degree(g, T, dir=:in))
71-
c = 1 ./ sqrt.(degree(g, T; dir=:in, edge_weight))
104+
d = degree(g, T; dir=:in, edge_weight)
105+
c = 1 ./ sqrt.(d)
72106
x = x .* c'
73107
if edge_weight === nothing
74108
x = propagate(copy_xj, g, +, xj=x)
75109
else
76110
x = propagate(e_mul_xj, g, +, xj=x, e=edge_weight)
77111
end
78-
x = x .* c'
112+
x = x .* c'
79113
if Dout >= Din
80114
x = l.weight * x
81115
end
82116
return l.σ.(x .+ l.bias)
83117
end
84118

119+
85120
function Base.show(io::IO, l::GCNConv)
86121
out, in = size(l.weight)
87122
print(io, "GCNConv($in => $out")

test/GNNGraphs/query.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,14 @@
6060
eweight = [0.1, 2.1, 1.2, 1]
6161
g = GNNGraph((s, t, eweight), graph_type=GRAPH_T)
6262
@test degree(g) == [2.2, 1.2, 1.0, 0.0]
63+
@test degree(g, edge_weight=nothing) == degree(g)
6364
d = degree(g, edge_weight=false)
6465
if GRAPH_T == :coo
6566
@test d == [2, 1, 1, 0]
66-
@test degree(g, edge_weight=nothing) == [2, 1, 1, 0]
6767
else
6868
# Adjacency matrix representation cannot disambiguate multiple edges
6969
# and edge weights
7070
@test d == [1, 1, 1, 0]
71-
@test degree(g, edge_weight=nothing) == [1, 1, 1, 0]
7271
end
7372
@test eltype(d) <: Integer
7473
if GRAPH_T == :coo

test/layers/conv.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,23 @@
3737

3838
l = GCNConv(in_channel => out_channel, add_self_loops=false)
3939
test_layer(l, g1, rtol=1e-5, outsize=(out_channel, g1.num_nodes))
40+
41+
@testset "edge weights" begin
42+
s = [2,3,1,3,1,2]
43+
t = [1,1,2,2,3,3]
44+
w = [1,2,3,4,5,6]
45+
g = GNNGraph((s, t, w), graph_type=GRAPH_T)
46+
x = ones(1, g.num_nodes)
47+
l = GCNConv(1 => 1, add_self_loops=false, use_edge_weight=true)
48+
l.weight .= 1
49+
d = degree(g, dir=:in)
50+
y = l(g, x)
51+
@test y[1,1] w[1] / (d[1]*d[2]) + w[2] / (d[1]*d[3])
52+
@test y[1,2] w[3] / (d[2]*d[1]) + w[4] / (d[2]*d[3])
53+
if GRAPH_T == :coo
54+
@test y l(g, x, w)
55+
end
56+
end
4057
end
4158

4259
@testset "ChebConv" begin

0 commit comments

Comments
 (0)