@@ -175,7 +175,7 @@ function (c::ChebConv)(g::GNNGraph, X::AbstractMatrix{T}) where T
175175     check_num_nodes(g, X)
176176     @assert size(X, 1) == size(c.weight, 2) "Input feature size must match input channel size."
177177
178-     L̃ = scaled_laplacian(g, eltype(X))
178+     L̃ = scaled_laplacian(g, eltype(X))
179179
180180 Z_prev = X
181181 Z = X * L̃
@@ -333,9 +333,9 @@ function (l::GATConv)(g::GNNGraph, x::AbstractMatrix)
333333         x = mean(x, dims=2)
334334 end
335335     x = reshape(x, :, size(x, 3))   # return a matrix
336-     x = l.σ.(x .+ l.bias)
336+     x = l.σ.(x .+ l.bias)
337337
338- return x
338+ return x
339339 end
340340
341341
@@ -346,6 +346,108 @@ function Base.show(io::IO, l::GATConv)
346346     print(io, "))")
347347 end
348348
349+ @doc raw """
350+     GATv2Conv(in => out, σ=identity;
351+               heads=1,
352+               concat=true,
353+               init=glorot_uniform,
354+               bias=true,
355+               negative_slope=0.2f0)
356+
357+ GATv2 attentional layer from the paper [How Attentive are Graph Attention Networks?](https://arxiv.org/abs/2105.14491).
358+
359+ Implements the operation
360+ ```math
361+ \mathbf{x}_i' = \sum_{j \in N(i) \cup \{i\}} \alpha_{ij} W_1 \mathbf{x}_j
362+ ```
363+ where the attention coefficients ``\alpha_{ij}`` are given by
364+ ```math
365+ \alpha_{ij} = \frac{1}{z_i} \exp(\mathbf{a}^T LeakyReLU([W_2 \mathbf{x}_i; W_1 \mathbf{x}_j]))
366+ ```
367+ with ``z_i`` a normalization factor ensuring that the coefficients sum to one over the neighborhood ``N(i) \cup \{i\}``.
368+
369+ # Arguments
370+
371+ - `in`: The dimension of input features.
372+ - `out`: The dimension of output features.
373+ - `bias`: Learn the additive bias if true.
374+ - `heads`: Number of attention heads.
375+ - `concat`: Concatenate layer output or not. If not, layer output is averaged over the heads.
376+ - `negative_slope`: The parameter of LeakyReLU.
377+ """
378+ struct GATv2Conv{T, A1, A2, B, C<:AbstractMatrix} <: GNNLayer
379+     dense_i::A1
380+     dense_j::A2
381+     bias::B
382+     a::C
383+     σ
384+     negative_slope::T
385+     channel::Pair{Int, Int}
386+     heads::Int
387+     concat::Bool
388+ end
389+
390+ @functor GATv2Conv
391+ Flux.trainable(l::GATv2Conv) = (l.dense_i, l.dense_j, l.bias, l.a)
392+
393+ function GATv2Conv(
394+     channel::Pair{Int,Int},
395+     σ=identity;
396+     heads::Int=1,
397+     concat::Bool=true,
398+     negative_slope=0.2,
399+     init=glorot_uniform,
400+     bias::Bool=true,
401+ )
402+     in, out = channel
403+     dense_i = Dense(in, out*heads; bias=bias, init=init)   # W_2: transform of the central node i
404+     dense_j = Dense(in, out*heads; bias=false, init=init)  # W_1: transform of the neighbor nodes j
405+     if concat
406+         b = bias ? Flux.create_bias(dense_i.weight, bias, out*heads) : false
407+     else
408+         b = bias ? Flux.create_bias(dense_i.weight, bias, out) : false
409+     end
410+     a = init(out, heads)   # attention vector a, one column per head
411+
412+     negative_slope = convert(eltype(dense_i.weight), negative_slope)
413+     GATv2Conv(dense_i, dense_j, b, a, σ, negative_slope, channel, heads, concat)
414+ end
415+
416+ function (l::GATv2Conv)(g::GNNGraph, x::AbstractMatrix)
417+     check_num_nodes(g, x)
418+     g = add_self_loops(g)
419+     in, out = l.channel
420+     heads = l.heads
421+
422+     Wix = reshape(l.dense_i(x), out, heads, :)   # out × heads × nnodes
423+     Wjx = reshape(l.dense_j(x), out, heads, :)   # out × heads × nnodes
424+
425+
426+     function message(Wix, Wjx, e)
427+         eij = sum(l.a .* leakyrelu.(Wix + Wjx, l.negative_slope), dims=1)   # 1 × heads × nedges
428+         α = exp.(eij)
429+         return (α = α, β = α .* Wjx)   # unnormalized weights and weighted neighbor features
430+     end
431+
432+     m = propagate(message, g, +; xi=Wix, xj=Wjx)   # out × heads × nnodes
433+     x = m.β ./ m.α   # softmax normalization: divide by z_i = Σ_j exp(eij)
434+
435+     if !l.concat
436+         x = mean(x, dims=2)
437+     end
438+     x = reshape(x, :, size(x, 3))
439+     x = l.σ.(x .+ l.bias)
440+     return x
441+ end
442+
443+
444+ function Base.show(io::IO, l::GATv2Conv)
445+     out, in = size(l.dense_i.weight)
446+     print(io, "GATv2Conv(", in, " => ", out ÷ l.heads)
447+     print(io, ", LeakyReLU(λ=", l.negative_slope)
448+     print(io, "))")
449+ end
450+
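For reference, a minimal usage sketch of the layer added above. The toy graph, the feature sizes, and the construction of a `GNNGraph` from an adjacency matrix are illustrative assumptions, not part of this diff:

```julia
using GraphNeuralNetworks, Flux

# Hypothetical toy graph: a directed 4-cycle given as an adjacency matrix.
adj = [0 1 0 0;
       0 0 1 0;
       0 0 0 1;
       1 0 0 0]
g = GNNGraph(adj)

x = rand(Float32, 3, g.num_nodes)                  # 3 input features per node
l = GATv2Conv(3 => 5, relu; heads=2, concat=true)  # 5 output features, 2 heads

y = l(g, x)                       # (5 * 2) × num_nodes, heads concatenated
@assert size(y) == (10, g.num_nodes)
```

With `concat=false` the two heads would be averaged instead, giving a 5 × num_nodes output.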
349451
350452 @doc raw """
351453 GatedGraphConv(out, num_layers; aggr=+, init=glorot_uniform)