Skip to content

Commit 157e78e

Browse files
authored
fixing overflow in GAT layers (#244)
* fixing overflow in GAT layers * one more small fix * an updated version of fix of GAT network
1 parent cef37c6 commit 157e78e

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

src/layers/conv.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -343,8 +343,11 @@ function (l::GATConv)(g::GNNGraph, x::AbstractMatrix, e::Union{Nothing,AbstractM
343343
Wx = l.dense_x(x)
344344
Wx = reshape(Wx, chout, heads, :) # chout × nheads × nnodes
345345

346-
m = propagate(message, g, +, l; xi=Wx, xj=Wx, e) ## chout × nheads × nnodes
347-
x = m.β ./ m.α
346+
# a hand-writtent message passing
347+
m = apply_edges((xi, xj, e) -> message(l, xi, xj, e), g, Wx, Wx, e)
348+
α = softmax_edge_neighbors(g, m.logα)
349+
β = α .* m.Wxj
350+
x = aggregate_neighbors(g, +, β)
348351

349352
if !l.concat
350353
x = mean(x, dims=2)
@@ -367,8 +370,8 @@ function message(l::GATConv, Wxi, Wxj, e)
367370
Wxx = vcat(Wxi, Wxj, We)
368371
end
369372
aWW = sum(l.a .* Wxx, dims=1) # 1 × nheads × nedges
370-
α = exp.(leakyrelu.(aWW, l.negative_slope))
371-
return = α, β = α .* Wxj)
373+
logα = leakyrelu.(aWW, l.negative_slope)
374+
return(logα = logα, Wxj = Wxj)
372375
end
373376

374377
function Base.show(io::IO, l::GATConv)

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ function softmax_edges(g::GNNGraph, e)
5454
num = exp.(e .- max_)
5555
den = reduce_edges(+, g, num)
5656
den = gather(den, gi)
57-
return num ./ den
57+
return num ./ (den .+ eps(eltype(e)))
5858
end
5959

6060
@doc raw"""

0 commit comments

Comments
 (0)