Skip to content

Commit 20d5b8a

Browse files
support edge_weight in GCNConv
1 parent 9264a68 commit 20d5b8a

File tree

4 files changed

+75
-34
lines changed

4 files changed

+75
-34
lines changed

src/GNNGraphs/query.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -95,18 +95,23 @@ function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType=eltype(g.g
9595
return dir == :out ? A : A'
9696
end
9797

98-
function Graphs.degree(g::GNNGraph{<:COO_T}, T=nothing; dir=:out, edge_weight=true)
99-
s, t = edge_index(g)
100-
98+
function _get_edge_weight(g, edge_weight)
10199
if edge_weight === true
102-
edge_weight = get_edge_weight(g)
100+
ew = get_edge_weight(g)
103101
elseif (edge_weight === false) || (edge_weight === nothing)
104-
edge_weight = nothing
102+
ew = nothing
105103
elseif edge_weight isa AbstractVector
106-
edge_weight = edge_weight
107-
else
104+
ew = edge_weight
105+
else
108106
error("Invalid edge_weight argument.")
109107
end
108+
return ew
109+
end
110+
111+
function Graphs.degree(g::GNNGraph{<:COO_T}, T=nothing; dir=:out, edge_weight=true)
112+
s, t = edge_index(g)
113+
114+
edge_weight = _get_edge_weight(g, edge_weight)
110115
edge_weight = isnothing(edge_weight) ? eltype(s)(1) : edge_weight
111116

112117
T = isnothing(T) ? eltype(edge_weight) : T

src/layers/conv.jl

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,81 @@
11
@doc raw"""
2-
GCNConv(in => out, σ=identity; bias=true, init=glorot_uniform, add_self_loops=true)
2+
GCNConv(in => out, σ=identity; bias=true, init=glorot_uniform, add_self_loops=true, edge_weight=true)
33
44
Graph convolutional layer from paper [Semi-supervised Classification with Graph Convolutional Networks](https://arxiv.org/abs/1609.02907).
55
66
Performs the operation
77
```math
8-
\mathbf{x}'_i = \sum_{j\in N(i)} \frac{1}{c_{ij}} W \mathbf{x}_j
8+
\mathbf{x}'_i = \sum_{j\in N(i)} a_{ij} W \mathbf{x}_j
99
```
10-
where ``c_{ij} = \sqrt{|N(i)||N(j)|}``.
10+
where ``a_{ij} = 1 / \sqrt{|N(i)||N(j)|}`` is a normalization factor computed from the node degrees.
1111
12-
The input to the layer is a node feature array `X`
13-
of size `(num_features, num_nodes)`.
12+
If the input graph has weighted edges and `edge_weight=true`, than ``c_{ij}`` will be computed as
13+
```math
14+
a_{ij} = \frac{e_{j\to i}}{\sqrt{\sum_{j \in N(i)} e_{j\to i}} \sqrt{\sum_{i \in N(j)} e_{i\to j}}}
15+
```
16+
17+
The input to the layer is a node feature array `X` of size `(num_features, num_nodes)`
18+
and optionally an edge weight vector.
1419
1520
# Arguments
1621
1722
- `in`: Number of input features.
1823
- `out`: Number of output features.
19-
- `σ`: Activation function.
20-
- `bias`: Add learnable bias.
21-
- `init`: Weights' initializer.
22-
- `add_self_loops`: Add self loops to the graph before performing the convolution.
24+
- `σ`: Activation function. Default `identity`.
25+
- `bias`: Add learnable bias. Default `true`.
26+
- `init`: Weights' initializer. Default `glorot_uniform`.
27+
- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `false`.
28+
- `edge_weight`. If `true`, consider the edge weights in the input graph (if available).
29+
Not compatible with `add_self_loops=true` at the moment. Default `true`.
2330
"""
2431
struct GCNConv{A<:AbstractMatrix, B, F} <: GNNLayer
2532
weight::A
2633
bias::B
2734
σ::F
2835
add_self_loops::Bool
36+
edge_weight::Bool
2937
end
3038

3139
@functor GCNConv
3240

3341
function GCNConv(ch::Pair{Int,Int}, σ=identity;
34-
init=glorot_uniform, bias::Bool=true,
35-
add_self_loops=true)
42+
init=glorot_uniform,
43+
bias::Bool=true,
44+
add_self_loops=true,
45+
edge_weight=false)
3646
in, out = ch
3747
W = init(out, in)
3848
b = bias ? Flux.create_bias(W, true, out) : false
39-
GCNConv(W, b, σ, add_self_loops)
49+
GCNConv(W, b, σ, add_self_loops, edge_weight)
4050
end
4151

42-
function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix{T}) where T
52+
function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix)
53+
# Extract edge_weight from g if available and l.edge_weight == false,
54+
# otherwise return nothing.
55+
edge_weight = GNNGraphs._get_edge_weight(g, l.edge_weight) # vector or nothing
56+
return l(g, x, edge_weight)
57+
end
58+
59+
function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix{T}, edge_weight::EW) where
60+
{T, EW<:Union{Nothing,AbstractVector}}
61+
4362
if l.add_self_loops
63+
@assert edge_weight === nothing
4464
g = add_self_loops(g)
4565
end
4666
Dout, Din = size(l.weight)
4767
if Dout < Din
4868
x = l.weight * x
4969
end
5070
# @assert all(>(0), degree(g, T, dir=:in))
51-
c = 1 ./ sqrt.(degree(g, T, dir=:in))
52-
x = x .* c'
53-
x = propagate(copy_xj, g, +, xj=x)
71+
c = 1 ./ sqrt.(degree(g, T; dir=:in, edge_weight))
5472
x = x .* c'
73+
if edge_weight === nothing
74+
x = propagate(copy_xj, g, +, xj=x)
75+
else
76+
x = propagate(e_mul_xj, g, +, xj=x, e=edge_weight)
77+
end
78+
x = x .* c'
5579
if Dout >= Din
5680
x = l.weight * x
5781
end

src/msgpass.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,18 @@ copy_xi(xi, xj, e) = xi
153153
"""
154154
xi_dot_xj(xi, xj, e) = sum(xi .* xj, dims=1)
155155

156+
"""
157+
e_mul_xj(xi, xj, e) = reshape(e, (...)) .* xj
158+
159+
Reshape `e` into broadcast compatible shape with `xj`
160+
(by prepending singleton dimensions) then perform
161+
broadcasted multiplication.
162+
"""
163+
function e_mul_xj(xi, xj::AbstractArray{Tj,Nj}, e::AbstractArray{Te,Ne}) where {Tj,Te, Nj, Ne}
164+
@assert Ne <= Nj
165+
e = reshape(e, ntuple(_ -> 1, Nj-Ne)..., size(e)...)
166+
return e .* xj
167+
end
156168

157169
function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, e)
158170
A = adjacency_matrix(g)

test/runtests.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,18 @@ ENV["DATADEPS_ALWAYS_ACCEPT"] = true # for MLDatasets
2121
include("test_utils.jl")
2222

2323
tests = [
24-
"GNNGraphs/gnngraph",
25-
"GNNGraphs/transform",
26-
"GNNGraphs/operators",
27-
"GNNGraphs/generate",
28-
"GNNGraphs/query",
29-
"utils",
30-
"msgpass",
31-
"layers/basic",
24+
# "GNNGraphs/gnngraph",
25+
# "GNNGraphs/transform",
26+
# "GNNGraphs/operators",
27+
# "GNNGraphs/generate",
28+
# "GNNGraphs/query",
29+
# "utils",
30+
# "msgpass",
31+
# "layers/basic",
3232
"layers/conv",
33-
"layers/pool",
34-
"examples/node_classification_cora",
35-
"deprecations",
33+
# "layers/pool",
34+
# "examples/node_classification_cora",
35+
# "deprecations",
3636
]
3737

3838
!CUDA.functional() && @warn("CUDA unavailable, not testing GPU support")

0 commit comments

Comments
 (0)