Commit 957ce36

Adds GATv2 layer (#97)
* Adds GATv2 layer
1 parent 4dec29c commit 957ce36

3 files changed: +126 −6 lines

src/GraphNeuralNetworks.jl

Lines changed: 1 addition & 0 deletions
@@ -50,6 +50,7 @@ export
 ChebConv,
 EdgeConv,
 GATConv,
+GATv2Conv,
 GatedGraphConv,
 GCNConv,
 GINConv,

src/layers/conv.jl

Lines changed: 105 additions & 3 deletions
@@ -175,7 +175,7 @@ function (c::ChebConv)(g::GNNGraph, X::AbstractMatrix{T}) where T
     check_num_nodes(g, X)
     @assert size(X, 1) == size(c.weight, 2) "Input feature size must match input channel size."

-    L̃ = scaled_laplacian(g, eltype(X))
+    L̃ = scaled_laplacian(g, eltype(X))

     Z_prev = X
     Z = X * L̃
@@ -333,9 +333,9 @@ function (l::GATConv)(g::GNNGraph, x::AbstractMatrix)
         x = mean(x, dims=2)
     end
     x = reshape(x, :, size(x, 3))    # return a matrix
-    x = l.σ.(x .+ l.bias)
+    x = l.σ.(x .+ l.bias)

-    return x
+    return x
 end

@@ -346,6 +346,108 @@ function Base.show(io::IO, l::GATConv)
     print(io, "))")
 end

+@doc raw"""
+    GATv2Conv(in => out, σ=identity;
+              heads=1,
+              concat=true,
+              init=glorot_uniform,
+              bias=true,
+              negative_slope=0.2f0)
+
+GATv2 attentional layer from the paper [How Attentive are Graph Attention Networks?](https://arxiv.org/abs/2105.14491).
+
+Implements the operation
+```math
+\mathbf{x}_i' = \sum_{j \in N(i) \cup \{i\}} \alpha_{ij} W_1 \mathbf{x}_j
+```
+where the attention coefficients ``\alpha_{ij}`` are given by
+```math
+\alpha_{ij} = \frac{1}{z_i} \exp(\mathbf{a}^T LeakyReLU([W_2 \mathbf{x}_i; W_1 \mathbf{x}_j]))
+```
+with ``z_i`` a normalization factor.
+
+# Arguments
+
+- `in`: The dimension of input features.
+- `out`: The dimension of output features.
+- `bias`: Learn the additive bias if true.
+- `heads`: Number of attention heads.
+- `concat`: Concatenate layer output or not. If not, the output is averaged over the heads.
+- `negative_slope`: The parameter of LeakyReLU.
+"""
+struct GATv2Conv{T, A1, A2, B, C<:AbstractMatrix} <: GNNLayer
+    dense_i::A1
+    dense_j::A2
+    bias::B
+    a::C
+    σ
+    negative_slope::T
+    channel::Pair{Int, Int}
+    heads::Int
+    concat::Bool
+end
+
+@functor GATv2Conv
+Flux.trainable(l::GATv2Conv) = (l.dense_i, l.dense_j, l.bias, l.a)
+
+function GATv2Conv(
+    channel::Pair{Int,Int},
+    σ=identity;
+    heads::Int=1,
+    concat::Bool=true,
+    negative_slope=0.2,
+    init=glorot_uniform,
+    bias::Bool=true,
+)
+    in, out = channel
+    dense_i = Dense(in, out*heads; bias=bias, init=init)
+    dense_j = Dense(in, out*heads; bias=false, init=init)
+    if concat
+        b = bias ? Flux.create_bias(dense_i.weight, bias, out*heads) : false
+    else
+        b = bias ? Flux.create_bias(dense_i.weight, bias, out) : false
+    end
+    a = init(out, heads)
+
+    negative_slope = convert(eltype(dense_i.weight), negative_slope)
+    GATv2Conv(dense_i, dense_j, b, a, σ, negative_slope, channel, heads, concat)
+end
+
+function (l::GATv2Conv)(g::GNNGraph, x::AbstractMatrix)
+    check_num_nodes(g, x)
+    g = add_self_loops(g)
+    in, out = l.channel
+    heads = l.heads
+
+    Wix = reshape(l.dense_i(x), out, heads, :)    # out × heads × nnodes
+    Wjx = reshape(l.dense_j(x), out, heads, :)    # out × heads × nnodes
+
+
+    function message(Wix, Wjx, e)
+        eij = sum(l.a .* leakyrelu.(Wix + Wjx, l.negative_slope), dims=1)    # 1 × heads × nedges
+        α = exp.(eij)
+        return (α = α, β = α .* Wjx)
+    end
+
+    m = propagate(message, g, +; xi=Wix, xj=Wjx)    # out × heads × nnodes
+    x = m.β ./ m.α
+
+    if !l.concat
+        x = mean(x, dims=2)
+    end
+    x = reshape(x, :, size(x, 3))
+    x = l.σ.(x .+ l.bias)
+    return x
+end
+
+
+function Base.show(io::IO, l::GATv2Conv)
+    out, in = size(l.dense_i.weight)
+    print(io, "GATv2Conv(", in, "=>", out ÷ l.heads)
+    print(io, ", LeakyReLU(λ=", l.negative_slope)
+    print(io, "))")
+end
+

 @doc raw"""
     GatedGraphConv(out, num_layers; aggr=+, init=glorot_uniform)
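
For orientation, a minimal usage sketch of the new layer (not part of the commit). It relies only on the constructor and forward signatures shown in the diff above; the graph, feature sizes, and variable names are illustrative, and constructing a GNNGraph directly from an adjacency matrix without node data is assumed to work as in the tests.

using GraphNeuralNetworks, Flux

# 4-node cycle graph, same adjacency pattern as in the tests below (illustrative)
adj = [0 1 0 1
       1 0 1 0
       0 1 0 1
       1 0 1 0]
g = GNNGraph(adj)

in_channel, out_channel, nheads = 3, 5, 2
x = rand(Float32, in_channel, g.num_nodes)    # node features, in × num_nodes

# two attention heads, head outputs concatenated
l = GATv2Conv(in_channel => out_channel, relu; heads=nheads, concat=true)
y = l(g, x)                                   # (nheads*out_channel) × num_nodes
@assert size(y) == (nheads*out_channel, g.num_nodes)

# averaged heads instead of concatenation
lavg = GATv2Conv(in_channel => out_channel; heads=nheads, concat=false)
@assert size(lavg(g, x)) == (out_channel, g.num_nodes)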

test/layers/conv.jl

Lines changed: 20 additions & 3 deletions
@@ -5,9 +5,9 @@
     T = Float32

     adj1 = [0 1 0 1
-            1 0 1 0
-            0 1 0 1
-            1 0 1 0]
+            1 0 1 0
+            0 1 0 1
+            1 0 1 0]

     g1 = GNNGraph(adj1,
                   ndata=rand(T, in_channel, N),
@@ -109,6 +109,23 @@
     end
 end

+@testset "GATv2Conv" begin
+
+    for heads in (1, 2), concat in (true, false)
+        l = GATv2Conv(in_channel => out_channel; heads, concat)
+        for g in test_graphs
+            test_layer(l, g, rtol=1e-4,
+                outsize=(concat ? heads*out_channel : out_channel, g.num_nodes))
+        end
+    end
+
+    @testset "bias=false" begin
+        @test length(Flux.params(GATv2Conv(2=>3))) == 5
+        @test length(Flux.params(GATv2Conv(2=>3, bias=false))) == 3
+    end
+end
+
+
 @testset "GatedGraphConv" begin
     num_layers = 3
     l = GatedGraphConv(out_channel, num_layers)
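
A note on the `bias=false` test above: with the default `bias=true`, `Flux.params` collects five arrays (the weight and bias of `dense_i`, the weight of `dense_j`, whose bias is always disabled, the extra output bias, and the attention vector `a`); with `bias=false` only the two weights and `a` remain, giving three. A quick illustrative re-check of the same counts (not part of the commit):

using Flux, GraphNeuralNetworks

l = GATv2Conv(2 => 3)                       # bias=true by default
# expected params: dense_i.weight, dense_i.bias, dense_j.weight, output bias, a
@assert length(Flux.params(l)) == 5

l = GATv2Conv(2 => 3, bias=false)
# expected params: dense_i.weight, dense_j.weight, a
@assert length(Flux.params(l)) == 3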

0 commit comments
