@@ -978,14 +978,14 @@ function Base.show(io::IO, l::CGConv)
end

@doc raw"""
-    AGNNConv(init_beta=1f0)
+    AGNNConv(; init_beta=1.0f0, trainable=true, add_self_loops=true)

Attention-based Graph Neural Network layer from paper [Attention-based
Graph Neural Network for Semi-Supervised Learning](https://arxiv.org/abs/1803.03735).

The forward pass is given by
```math
-\mathbf{x}_i' = \sum_{j \in N(i) \cup \{i\}} \alpha_{ij} W \mathbf{x}_j
+\mathbf{x}_i' = \sum_{j \in N(i)} \alpha_{ij} \mathbf{x}_j
```
where the attention coefficients ``\alpha_{ij}`` are given by
```math
@@ -997,32 +997,40 @@ with the cosine distance defined by
\cos(\mathbf{x}_i, \mathbf{x}_j) =
\frac{\mathbf{x}_i \cdot \mathbf{x}_j}{\lVert\mathbf{x}_i\rVert \lVert\mathbf{x}_j\rVert}
```
-and ``\beta`` a trainable parameter.
+and ``\beta`` a trainable parameter if `trainable=true`.

# Arguments

-- `init_beta`: The initial value of ``\beta``.
+- `init_beta`: The initial value of ``\beta``. Default `1.0f0`.
+- `trainable`: If `true`, ``\beta`` is trainable. Default `true`.
+- `add_self_loops`: Add self-loops to the graph before performing the convolution. Default `true`.
"""
struct AGNNConv{A <: AbstractVector} <: GNNLayer
    β::A
+    add_self_loops::Bool
+    trainable::Bool
end

@functor AGNNConv

-function AGNNConv(init_beta = 1.0f0)
-    AGNNConv([init_beta])
+Flux.trainable(l::AGNNConv) = l.trainable ? (; l.β) : (;)
+
+function AGNNConv(; init_beta = 1.0f0, add_self_loops = true, trainable = true)
+    AGNNConv([init_beta], add_self_loops, trainable)
end

function (l::AGNNConv)(g::GNNGraph, x::AbstractMatrix)
    check_num_nodes(g, x)
-    g = add_self_loops(g)
+    if l.add_self_loops
+        g = add_self_loops(g)
+    end

    xn = x ./ sqrt.(sum(x .^ 2, dims = 1))
    cos_dist = apply_edges(xi_dot_xj, g, xi = xn, xj = xn)
    α = softmax_edge_neighbors(g, l.β .* cos_dist)

    x = propagate(g, +; xj = x, e = α) do xi, xj, α
-        α .* xj
+        α .* xj
    end

    return x
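
For reference, a minimal usage sketch of the changed constructor and forward pass; it assumes standard GraphNeuralNetworks.jl usage (`rand_graph`, features stored as a `num_features × num_nodes` matrix) and is not part of this diff:

```julia
# Usage sketch (assumed setup, not part of this diff): exercise the new
# keyword constructor of AGNNConv.
using Flux, GraphNeuralNetworks

g = rand_graph(10, 30)      # graph with 10 nodes and 30 edges
x = rand(Float32, 8, 10)    # 8 features per node

# Defaults: β is trainable and self-loops are added inside the forward pass.
l = AGNNConv(init_beta = 1.0f0)
y = l(g, x)                 # size(y) == (8, 10): AGNNConv has no weight matrix

# Freeze β and skip the implicit self-loops:
l_fixed = AGNNConv(init_beta = 0.5f0, trainable = false, add_self_loops = false)
Flux.trainable(l_fixed)     # == (;), so no parameters are exposed to the optimizer
```

Storing `β` in a one-element vector keeps it a mutable parameter array for Flux, and the new `Flux.trainable` overload is what hides it from optimizers when `trainable=false`.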