@@ -343,7 +343,7 @@ function (l::GATConv)(g::GNNGraph, x::AbstractMatrix, e::Union{Nothing,AbstractM
     Wx = l.dense_x(x)
     Wx = reshape(Wx, chout, heads, :)                          # chout × nheads × nnodes
 
-    # a hand-writtent message passing
+    # a hand-written message passing
     m = apply_edges((xi, xj, e) -> message(l, xi, xj, e), g, Wx, Wx, e)
     α = softmax_edge_neighbors(g, m.logα)
     β = α .* m.Wxj
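For context, the rewritten GATConv forward pass spells the attention out as three explicit steps: score each edge, normalize the scores over each node's incoming edges, then aggregate the weighted neighbor features. A minimal standalone sketch of that same pattern (not part of the diff), assuming GraphNeuralNetworks.jl is loaded and using a toy dot-product score in place of the layer's learned attention:

    using GraphNeuralNetworks

    g = GNNGraph([1, 2, 3], [2, 3, 1])            # toy cycle: 3 nodes, 3 directed edges
    x = rand(Float32, 4, 3)                       # 4 features per node

    # 1. per-edge values from the gathered target (xi) and source (xj) features;
    #    the dot-product score is a stand-in for the learned attention
    m = apply_edges((xi, xj, e) -> (; logα = sum(xi .* xj, dims = 1), xj), g, x, x, nothing)
    # 2. normalize the scores over each node's incoming edges
    α = softmax_edge_neighbors(g, m.logα)         # 1 × nedges
    # 3. weighted sum of neighbor features back onto the nodes
    out = aggregate_neighbors(g, +, α .* m.xj)    # 4 × nnodes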
@@ -371,7 +371,7 @@ function message(l::GATConv, Wxi, Wxj, e)
     end
     aWW = sum(l.a .* Wxx, dims=1)                              # 1 × nheads × nedges
     logα = leakyrelu.(aWW, l.negative_slope)
-    return (logα = logα, Wxj = Wxj)
+    return (; logα, Wxj)
 end
 
 function Base.show(io::IO, l::GATConv)
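The only change here is the return expression: `(; logα, Wxj)` is Julia's shorthand for building a named tuple from variables of the same name, so the value returned to the caller is unchanged. A quick check, not part of the diff:

    logα = [0.1 0.2]; Wxj = [1.0 2.0]
    (; logα, Wxj) == (logα = logα, Wxj = Wxj)   # true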
@@ -480,11 +480,13 @@ function (l::GATv2Conv)(g::GNNGraph, x::AbstractMatrix, e::Union{Nothing, Abstra
     _, out = l.channel
     heads = l.heads
 
-    Wix = reshape(l.dense_i(x), out, heads, :)                 # out × heads × nnodes
-    Wjx = reshape(l.dense_j(x), out, heads, :)                 # out × heads × nnodes
+    Wxi = reshape(l.dense_i(x), out, heads, :)                 # out × heads × nnodes
+    Wxj = reshape(l.dense_j(x), out, heads, :)                 # out × heads × nnodes
 
-    m = propagate(message, g, +, l; xi=Wix, xj=Wjx, e)         # out × heads × nnodes
-    x = m.β ./ m.α
+    m = apply_edges((xi, xj, e) -> message(l, xi, xj, e), g, Wxi, Wxj, e)
+    α = softmax_edge_neighbors(g, m.logα)
+    β = α .* m.Wxj
+    x = aggregate_neighbors(g, +, β)
 
     if !l.concat
         x = mean(x, dims=2)
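For reference, per attention head the rewritten forward pass together with the message function in the next hunk computes the GATv2 attention from the paper cited in the code comment, with i the target node, j ranging over its in-neighbors, and the edge term present only when edge features are given:

    logα_ij = aᵀ LeakyReLU(Wᵢ xᵢ + Wⱼ xⱼ + Wₑ e_ij)
    α_ij    = softmax over j of logα_ij
    xᵢ'     = Σⱼ α_ij · (Wⱼ xⱼ)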
@@ -494,17 +496,16 @@ function (l::GATv2Conv)(g::GNNGraph, x::AbstractMatrix, e::Union{Nothing, Abstra
     return x
 end
 
-function message(l::GATv2Conv, Wix, Wjx, e)
+function message(l::GATv2Conv, Wxi, Wxj, e)
     _, out = l.channel
     heads = l.heads
 
-    Wx = Wix + Wjx   # Note: this is equivalent to W * vcat(x_i, x_j) as in "How Attentive are Graph Attention Networks?"
+    Wx = Wxi + Wxj   # Note: this is equivalent to W * vcat(x_i, x_j) as in "How Attentive are Graph Attention Networks?"
     if e !== nothing
         Wx += reshape(l.dense_e(e), out, heads, :)
     end
-    eij = sum(l.a .* leakyrelu.(Wx, l.negative_slope), dims=1)   # 1 × heads × nedges
-    α = exp.(eij)
-    return (α = α, β = α .* Wjx)
+    logα = sum(l.a .* leakyrelu.(Wx, l.negative_slope), dims=1)  # 1 × heads × nedges
+    return (; logα, Wxj)
 end
 
 function Base.show(io::IO, l::GATv2Conv)
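The net effect of the last two hunks is that the softmax over each node's incoming edges, which the old code assembled by hand (returning exp.(eij) from message, summing via propagate with +, and dividing the aggregates), is now delegated to softmax_edge_neighbors. A rough equivalence sketch, not part of the diff, assuming NNlib's gather and ignoring the max-shift that softmax_edge_neighbors applies for numerical stability:

    using GraphNeuralNetworks
    using NNlib: gather

    g = GNNGraph([1, 1, 2, 3], [2, 3, 3, 1])      # toy graph: 3 nodes, 4 edges
    logα = rand(Float32, 1, 2, 4)                 # 1 × heads × nedges, heads = 2

    _, t = edge_index(g)                          # t[k] = target node of edge k
    num = exp.(logα)
    den = aggregate_neighbors(g, +, num)          # per-node sums over incoming edges
    α_manual = num ./ gather(den, t)              # the old exp-and-divide normalization
    α_manual ≈ softmax_edge_neighbors(g, logα)    # true (up to floating point)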