Skip to content

Commit 2020e58

Browse files
numerically stable GATv2Conv (#247)
1 parent edfdff3 commit 2020e58

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

src/layers/conv.jl

Lines changed: 12 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -343,7 +343,7 @@ function (l::GATConv)(g::GNNGraph, x::AbstractMatrix, e::Union{Nothing,AbstractM
343343
Wx = l.dense_x(x)
344344
Wx = reshape(Wx, chout, heads, :) # chout × nheads × nnodes
345345

346-
# a hand-writtent message passing
346+
# a hand-written message passing
347347
m = apply_edges((xi, xj, e) -> message(l, xi, xj, e), g, Wx, Wx, e)
348348
α = softmax_edge_neighbors(g, m.logα)
349349
β = α .* m.Wxj
@@ -371,7 +371,7 @@ function message(l::GATConv, Wxi, Wxj, e)
371371
end
372372
aWW = sum(l.a .* Wxx, dims=1) # 1 × nheads × nedges
373373
logα = leakyrelu.(aWW, l.negative_slope)
374-
return(logα = logα, Wxj = Wxj)
374+
return (; logα, Wxj)
375375
end
376376

377377
function Base.show(io::IO, l::GATConv)
@@ -480,11 +480,13 @@ function (l::GATv2Conv)(g::GNNGraph, x::AbstractMatrix, e::Union{Nothing, Abstra
480480
_, out = l.channel
481481
heads = l.heads
482482

483-
Wix = reshape(l.dense_i(x), out, heads, :) # out × heads × nnodes
484-
Wjx = reshape(l.dense_j(x), out, heads, :) # out × heads × nnodes
483+
Wxi = reshape(l.dense_i(x), out, heads, :) # out × heads × nnodes
484+
Wxj = reshape(l.dense_j(x), out, heads, :) # out × heads × nnodes
485485

486-
m = propagate(message, g, +, l; xi=Wix, xj=Wjx, e) # out × heads × nnodes
487-
x = m.β ./ m.α
486+
m = apply_edges((xi, xj, e) -> message(l, xi, xj, e), g, Wxi, Wxj, e)
487+
α = softmax_edge_neighbors(g, m.logα)
488+
β = α .* m.Wxj
489+
x = aggregate_neighbors(g, +, β)
488490

489491
if !l.concat
490492
x = mean(x, dims=2)
@@ -494,17 +496,16 @@ function (l::GATv2Conv)(g::GNNGraph, x::AbstractMatrix, e::Union{Nothing, Abstra
494496
return x
495497
end
496498

497-
function message(l::GATv2Conv, Wix, Wjx, e)
499+
function message(l::GATv2Conv, Wxi, Wxj, e)
498500
_, out = l.channel
499501
heads = l.heads
500502

501-
Wx = Wix + Wjx # Note: this is equivalent to W * vcat(x_i, x_j) as in "How Attentive are Graph Attention Networks?"
503+
Wx = Wxi + Wxj # Note: this is equivalent to W * vcat(x_i, x_j) as in "How Attentive are Graph Attention Networks?"
502504
if e !== nothing
503505
Wx += reshape(l.dense_e(e), out, heads, :)
504506
end
505-
eij = sum(l.a .* leakyrelu.(Wx, l.negative_slope), dims=1) # 1 × heads × nedges
506-
α = exp.(eij)
507-
return (α = α, β = α .* Wjx)
507+
logα = sum(l.a .* leakyrelu.(Wx, l.negative_slope), dims=1) # 1 × heads × nedges
508+
return (; logα, Wxj)
508509
end
509510

510511
function Base.show(io::IO, l::GATv2Conv)

0 commit comments

Comments
 (0)