
Commit eed2197

Merge pull request #43 from CarloLucibello/cl/dev
add SAGEConv + fix bug in GATConv
2 parents f0d8926 + 56acbeb commit eed2197

4 files changed: +111 −31 lines changed


src/GraphNeuralNetworks.jl
Lines changed: 1 addition & 0 deletions

@@ -46,6 +46,7 @@ export
   GINConv,
   GraphConv,
   NNConv,
+  SAGEConv,
 
   # layers/pool
   GlobalPool,

src/layers/conv.jl
Lines changed: 73 additions & 13 deletions

@@ -142,7 +142,7 @@ end
 
 
 @doc raw"""
-    GraphConv(in => out, σ=identity, aggr=+; bias=true, init=glorot_uniform)
+    GraphConv(in => out, σ=identity; aggr=+, bias=true, init=glorot_uniform)
 
 Graph convolution layer from Reference: [Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks](https://arxiv.org/abs/1810.02244).
 
@@ -172,7 +172,7 @@ end
 
 @functor GraphConv
 
-function GraphConv(ch::Pair{Int,Int}, σ=identity, aggr=+;
+function GraphConv(ch::Pair{Int,Int}, σ=identity; aggr=+,
                    init=glorot_uniform, bias::Bool=true)
     in, out = ch
     W1 = init(out, in)
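In practice the change just moves `aggr` from the third positional slot to a keyword. A minimal usage sketch of the new call form (the tiny graph, feature sizes, and the `relu`/`mean` choices are illustrative, not part of the commit):

```julia
using GraphNeuralNetworks, Flux
using Statistics: mean

# Tiny illustrative graph built from (source, target) vectors.
s, t = [1, 2, 3], [2, 3, 1]
g = GNNGraph(s, t)
x = rand(Float32, 3, g.num_nodes)       # 3 input features per node

# `aggr` is now a keyword (it used to be the third positional argument).
l = GraphConv(3 => 8, relu; aggr=mean)
y = l(g, x)                             # expected size: (8, g.num_nodes)
```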
@@ -214,9 +214,9 @@ Implements the operation
 ```math
 \mathbf{x}_i' = \sum_{j \in N(i)} \alpha_{ij} W \mathbf{x}_j
 ```
-where the attention coefficient ``\alpha_{ij}`` is given by
+where the attention coefficients ``\alpha_{ij}`` are given by
 ```math
-\alpha_{ij} = \frac{1}{z_i} \exp(LeakyReLU(\mathbf{a}^T [W \mathbf{x}_i || W \mathbf{x}_j]))
+\alpha_{ij} = \frac{1}{z_i} \exp(LeakyReLU(\mathbf{a}^T [W \mathbf{x}_i \,\|\, W \mathbf{x}_j]))
 ```
 with ``z_i`` a normalization factor.
@@ -225,9 +225,9 @@ with ``z_i`` a normalization factor.
 - `in`: The dimension of input features.
 - `out`: The dimension of output features.
 - `bias::Bool`: Keyword argument, whether to learn the additive bias.
-- `heads`: Number attention heads
+- `heads`: Number of attention heads.
 - `concat`: Concatenate layer output or not. If not, layer output is averaged over the heads.
-- `negative_slope::Real`: Keyword argument, the parameter of LeakyReLU.
+- `negative_slope`: The parameter of LeakyReLU.
 """
 struct GATConv{T, A<:AbstractMatrix, B} <: GNNLayer
     weight::A
@@ -248,14 +248,18 @@ function GATConv(ch::Pair{Int,Int}, σ=identity;
                  init=glorot_uniform, bias::Bool=true)
     in, out = ch
     W = init(out*heads, in)
-    b = bias ? Flux.create_bias(W, true, out*heads) : false
+    if concat
+        b = bias ? Flux.create_bias(W, true, out*heads) : false
+    else
+        b = bias ? Flux.create_bias(W, true, out) : false
+    end
     a = init(2*out, heads)
     negative_slope = convert(eltype(W), negative_slope)
     GATConv(W, b, a, σ, negative_slope, ch, heads, concat)
 end
 
 function compute_message(l::GATConv, Wxi, Wxj)
-    aWW = sum(l.a .* cat(Wxi, Wxj, dims=1), dims=1)   # 1 × nheads × nedges
+    aWW = sum(l.a .* vcat(Wxi, Wxj), dims=1)          # 1 × nheads × nedges
     α = exp.(leakyrelu.(aWW, l.negative_slope))
     return (α = α, m = α .* Wxj)
 end
@@ -273,14 +277,13 @@ function (l::GATConv)(g::GNNGraph, x::AbstractMatrix)
 
     x, _ = propagate(l, g, +, Wx)    ## chout × nheads × nnodes
 
-    b = reshape(l.bias, chout, heads)
-    x = l.σ.(x .+ b)
     if !l.concat
-        x = sum(x, dims=2)
+        x = mean(x, dims=2)
     end
+    x = reshape(x, :, size(x, 3))  # return a matrix
+    x = l.σ.(x .+ l.bias)
 
-    # We finally return a matrix
-    return reshape(x, :, size(x, 3))
+    return x
 end
 
 
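The net effect of the GATConv fix: with `concat=false` the per-head outputs are now averaged rather than summed, and the bias is created with `out` rows so it matches the reduced output. A short sketch of the expected shapes (the tiny graph, feature sizes, and head count are illustrative, not from the commit):

```julia
using GraphNeuralNetworks, Flux

s, t = [1, 2, 3], [2, 3, 1]
g = GNNGraph(s, t)
x = rand(Float32, 4, g.num_nodes)

# concat=true: the heads are concatenated, so bias and output have out*heads rows.
l_cat = GATConv(4 => 8; heads=3, concat=true)
@assert size(l_cat(g, x)) == (8 * 3, g.num_nodes)

# concat=false: the heads are now averaged (mean instead of sum) and the bias
# has only `out` rows, matching the averaged output.
l_avg = GATConv(4 => 8; heads=3, concat=false)
@assert size(l_avg(g, x)) == (8, g.num_nodes)
```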
@@ -514,3 +517,60 @@ function Base.show(io::IO, l::NNConv)
     print(io, ", aggr=", l.aggr)
     print(io, ")")
 end
+
+
+@doc raw"""
+    SAGEConv(in => out, σ=identity; aggr=mean, bias=true, init=glorot_uniform)
+
+GraphSAGE convolution layer from paper [Inductive Representation Learning on Large Graphs](https://arxiv.org/pdf/1706.02216.pdf).
+
+Performs:
+```math
+\mathbf{x}_i' = W [\mathbf{x}_i \,\|\, \square_{j \in \mathcal{N}(i)} \mathbf{x}_j]
+```
+
+where the aggregation type is selected by `aggr`.
+
+# Arguments
+
+- `in`: The dimension of input features.
+- `out`: The dimension of output features.
+- `σ`: Activation function.
+- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
+- `bias`: Add learnable bias.
+- `init`: Weights' initializer.
+"""
+struct SAGEConv{A<:AbstractMatrix, B} <: GNNLayer
+    weight::A
+    bias::B
+    σ
+    aggr
+end
+
+@functor SAGEConv
+
+function SAGEConv(ch::Pair{Int,Int}, σ=identity; aggr=mean,
+                  init=glorot_uniform, bias::Bool=true)
+    in, out = ch
+    W = init(out, 2*in)
+    b = bias ? Flux.create_bias(W, true, out) : false
+    SAGEConv(W, b, σ, aggr)
+end
+
+compute_message(l::SAGEConv, x_i, x_j, e_ij) = x_j
+update_node(l::SAGEConv, m, x) = l.σ.(l.weight * vcat(x, m) .+ l.bias)
+
+function (l::SAGEConv)(g::GNNGraph, x::AbstractMatrix)
+    check_num_nodes(g, x)
+    x, _ = propagate(l, g, l.aggr, x)
+    x
+end
+
+function Base.show(io::IO, l::SAGEConv)
+    in_channel = size(l.weight, ndims(l.weight)) ÷ 2
+    out_channel = size(l.weight, ndims(l.weight)-1)
+    print(io, "SAGEConv(", in_channel, " => ", out_channel)
+    l.σ == identity || print(io, ", ", l.σ)
+    print(io, ", aggr=", l.aggr)
+    print(io, ")")
+end
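A minimal sketch of how the new layer is used, consistent with the docstring above and the tests added below (the tiny graph and feature sizes are made up for illustration):

```julia
using GraphNeuralNetworks, Flux

s, t = [1, 2, 3, 1], [2, 3, 1, 3]
g = GNNGraph(s, t)
x = rand(Float32, 5, g.num_nodes)

# Default neighbourhood aggregation is `mean`, as in the docstring.
l = SAGEConv(5 => 16)
y = l(g, x)                              # expected size: (16, g.num_nodes)

# Activation and aggregation can be changed, mirroring the new test below.
l2 = SAGEConv(5 => 16, tanh; aggr=+, bias=false)
y2 = l2(g, x)
```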

test/layers/conv.jl
Lines changed: 27 additions & 13 deletions

@@ -27,16 +27,16 @@
 @testset "GCNConv" begin
     l = GCNConv(in_channel => out_channel)
     for g in test_graphs
-        test_layer(l, g, rtol=1e-5)
+        test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
     end
 
     l = GCNConv(in_channel => out_channel, tanh, bias=false)
     for g in test_graphs
-        test_layer(l, g, rtol=1e-5)
+        test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
     end
 
     l = GCNConv(in_channel => out_channel, add_self_loops=false)
-    test_layer(l, g1, rtol=1e-5)
+    test_layer(l, g1, rtol=1e-5, outsize=(out_channel, g1.num_nodes))
 end
 
 @testset "ChebConv" begin
@@ -65,12 +65,12 @@
 @testset "GraphConv" begin
     l = GraphConv(in_channel => out_channel)
     for g in test_graphs
-        test_layer(l, g, rtol=1e-5)
+        test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
     end
 
-    l = GraphConv(in_channel => out_channel, relu, bias=false)
+    l = GraphConv(in_channel => out_channel, relu, bias=false, aggr=mean)
     for g in test_graphs
-        test_layer(l, g, rtol=1e-5)
+        test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
     end
 
     @testset "bias=false" begin
@@ -81,10 +81,11 @@
 
 @testset "GATConv" begin
 
-    for heads in (1, 2), concat in (true, false)
+    for heads in (1, 3), concat in (true, false)
         l = GATConv(in_channel => out_channel; heads, concat)
         for g in test_graphs
-            test_layer(l, g, rtol=1e-4)
+            test_layer(l, g, rtol=1e-4,
+                outsize=(concat ? heads*out_channel : out_channel, g.num_nodes))
         end
     end
 
@@ -100,14 +101,14 @@
     @test size(l.weight) == (out_channel, out_channel, num_layers)
 
     for g in test_graphs
-        test_layer(l, g, rtol=1e-5)
+        test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
     end
 end
 
 @testset "EdgeConv" begin
     l = EdgeConv(Dense(2*in_channel, out_channel), aggr=+)
     for g in test_graphs
-        test_layer(l, g, rtol=1e-5)
+        test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
     end
 end
 
@@ -116,7 +117,7 @@
     eps = 0.001f0
     l = GINConv(nn, eps=eps)
     for g in test_graphs
-        test_layer(l, g, rtol=1e-5, exclude_grad_fields=[:eps])
+        test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes), exclude_grad_fields=[:eps])
     end
 
     @test !in(:eps, Flux.trainable(l))
@@ -129,13 +130,26 @@
     l = NNConv(in_channel => out_channel, nn)
     for g in test_graphs
         g = GNNGraph(g, edata=rand(T, edim, g.num_edges))
-        test_layer(l, g, rtol=1e-5)
+        test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
     end
 
     l = NNConv(in_channel => out_channel, nn, tanh, bias=false, aggr=mean)
     for g in test_graphs
         g = GNNGraph(g, edata=rand(T, edim, g.num_edges))
-        test_layer(l, g, rtol=1e-5)
+        test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
+    end
+end
+
+@testset "SAGEConv" begin
+    l = SAGEConv(in_channel => out_channel)
+    @test l.aggr == mean
+    for g in test_graphs
+        test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
+    end
+
+    l = SAGEConv(in_channel => out_channel, tanh, bias=false, aggr=+)
+    for g in test_graphs
+        test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
     end
 end
 end

test/test_utils.jl
Lines changed: 10 additions & 5 deletions

@@ -12,11 +12,12 @@ function FiniteDifferences.to_vec(x::Integer)
     return Int[x], Integer_from_vec
 end
 
-function test_layer(l, g::GNNGraph; atol=1e-7, rtol=1e-5,
-                    exclude_grad_fields=[],
-                    broken_grad_fields=[],
+function test_layer(l, g::GNNGraph; atol = 1e-7, rtol = 1e-5,
+                    exclude_grad_fields = [],
+                    broken_grad_fields = [],
                     verbose = false,
                     test_gpu = TEST_GPU,
+                    outsize = nothing,
                     )
 
     # TODO these give errors, probably some bugs in ChainRulesTestUtils
@@ -29,7 +30,7 @@ function test_layer(l, g::GNNGraph; atol=1e-7, rtol=1e-5,
     x = node_features(g)
     e = edge_features(g)
 
-    x64, e64, l64, g64 = to64.([x, e, l, g])
+    x64, e64, l64, g64 = to64.([x, e, l, g]) # needed for accurate FiniteDifferences' grad
     xgpu, egpu, lgpu, ggpu = gpu.([x, e, l, g])
 
     f(l, g) = l(g)
@@ -45,7 +46,11 @@ function test_layer(l, g::GNNGraph; atol=1e-7, rtol=1e-5,
     # TEST OUTPUT
     y = f(l, g, x)
     @test eltype(y) == eltype(x)
-
+    @test all(isfinite, y)
+    if !isnothing(outsize)
+        @test size(y) == outsize
+    end
+
 
     g′ = f(l, g)
     @test g′.ndata.x ≈ y
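For context, a sketch of how the new `outsize` keyword is meant to be called from the test suites above; `in_channel`, `out_channel`, and `test_graphs` are fixtures defined elsewhere in the tests, and `outsize` defaults to `nothing` (no size check):

```julia
# test_layer now additionally checks that the layer output is finite and,
# when `outsize` is given, that it has the expected size.
l = SAGEConv(in_channel => out_channel)
for g in test_graphs
    test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
end
```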