@@ -50,7 +50,7 @@ function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix{T}) where T
    # @assert all(>(0), degree(g, T, dir=:in))
    c = 1 ./ sqrt.(degree(g, T, dir=:in))
    x = x .* c'
-    x = propagate(copyxj, g, +, xj=x)
+    x = propagate(copy_xj, g, +, xj=x)
    x = x .* c'
    if Dout >= Din
        x = l.weight * x
@@ -179,7 +179,7 @@

function (l::GraphConv)(g::GNNGraph, x::AbstractMatrix)
    check_num_nodes(g, x)
-    m = propagate(copyxj, g, l.aggr, xj=x)
+    m = propagate(copy_xj, g, l.aggr, xj=x)
    x = l.σ.(l.weight1 * x .+ l.weight2 * m .+ l.bias)
    return x
end
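The only functional change in these hunks is the rename `copyxj` → `copy_xj`. As a reminder of what that function does, here is a minimal sketch of the idea (the library's own definition is a one-liner to the same effect):

```julia
# A message function in the propagate/apply_edges API receives the destination
# features xi, the source features xj, and the edge features e, and returns the
# message sent along each edge. copy_xj simply forwards the source features.
copy_xj(xi, xj, e) = xj

# With the + aggregation, propagate(copy_xj, g, +, xj=x) sums the feature
# columns of each node's in-neighbors, i.e. it multiplies the (normalized)
# feature matrix by the graph's adjacency structure.
```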
@@ -206,7 +206,7 @@ Graph attentional layer from the paper [Graph Attention Networks](https://arxiv.

Implements the operation
```math
-\mathbf{x}_i' = \sum_{j \in N(i)} \alpha_{ij} W \mathbf{x}_j
+\mathbf{x}_i' = \sum_{j \in N(i) \cup \{i\}} \alpha_{ij} W \mathbf{x}_j
```
where the attention coefficients ``\alpha_{ij}`` are given by
```math
@@ -338,7 +338,7 @@ function (l::GatedGraphConv)(g::GNNGraph, H::AbstractMatrix{S}) where {S<:Real}
    end
    for i = 1:l.num_layers
        M = view(l.weight, :, :, i) * H
-        M = propagate(copyxj, g, l.aggr; xj=M)
+        M = propagate(copy_xj, g, l.aggr; xj=M)
        H, _ = l.gru(H, M)
    end
    H
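Reading the loop off the code: each of the `num_layers` steps forms messages from the current state, aggregates them over the in-neighborhood, and feeds the result to the GRU cell. Per node ``i`` and step ``k`` this is

```math
\mathbf{m}_i^{(k)} = \square_{j \in N(i)} \mathbf{W}^{(k)} \mathbf{h}_j^{(k-1)}, \qquad
\mathbf{h}_i^{(k)} = \mathrm{GRU}\big(\mathbf{h}_i^{(k-1)}, \mathbf{m}_i^{(k)}\big)
```

where ``\square`` is the aggregation `l.aggr` and ``\mathbf{W}^{(k)}`` is the ``k``-th slice of `l.weight`.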
@@ -420,7 +420,7 @@ GINConv(nn, ϵ; aggr=+) = GINConv(nn, ϵ, aggr)

function (l::GINConv)(g::GNNGraph, x::AbstractMatrix)
    check_num_nodes(g, x)
-    m = propagate(copyxj, g, l.aggr, xj=x)
+    m = propagate(copy_xj, g, l.aggr, xj=x)
    l.nn((1 + ofeltype(x, l.ϵ)) * x + m)
end
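Reading this forward pass off the code, the GIN layer computes

```math
\mathbf{x}_i' = f_\Theta\Big((1 + \epsilon)\,\mathbf{x}_i + \square_{j \in N(i)} \mathbf{x}_j\Big)
```

with ``f_\Theta`` the wrapped network `l.nn`, ``\epsilon`` the scalar `l.ϵ`, and ``\square`` the aggregation `l.aggr`.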
@@ -542,7 +542,7 @@

function (l::SAGEConv)(g::GNNGraph, x::AbstractMatrix)
    check_num_nodes(g, x)
-    m = propagate(copyxj, g, l.aggr, xj=x)
+    m = propagate(copy_xj, g, l.aggr, xj=x)
    x = l.σ.(l.weight * vcat(x, m) .+ l.bias)
    return x
end
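Likewise, the SAGEConv forward pass shown here amounts to

```math
\mathbf{x}_i' = \sigma\Big(\mathbf{W}\,\big[\mathbf{x}_i \,\|\, \square_{j \in N(i)} \mathbf{x}_j\big] + \mathbf{b}\Big)
```

where ``\|`` denotes the feature concatenation performed by `vcat` and ``\square`` is `l.aggr`.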
@@ -711,3 +711,56 @@ function Base.show(io::IO, l::CGConv)
    print(io, ", residual=$(l.residual)")
    print(io, ")")
end
+
+
+@doc raw"""
+    AGNNConv(init_beta=1f0)
+
+Attention-based Graph Neural Network layer from the paper [Attention-based
+Graph Neural Network for Semi-Supervised Learning](https://arxiv.org/abs/1803.03735).
+
+The forward pass is given by
+```math
+\mathbf{x}_i' = \sum_{j \in N(i) \cup \{i\}} \alpha_{ij} \mathbf{x}_j
+```
+where the attention coefficients ``\alpha_{ij}`` are given by
+```math
+\alpha_{ij} = \frac{e^{\beta \cos(\mathbf{x}_i, \mathbf{x}_j)}}
+             {\sum_{j'} e^{\beta \cos(\mathbf{x}_i, \mathbf{x}_{j'})}}
+```
+with the cosine distance defined by
+```math
+\cos(\mathbf{x}_i, \mathbf{x}_j) =
+  \mathbf{x}_i \cdot \mathbf{x}_j / \lVert\mathbf{x}_i\rVert\,\lVert\mathbf{x}_j\rVert
+```
+and ``\beta`` a trainable parameter.
+
+# Arguments
+
+- `init_beta`: The initial value of ``\beta``.
+"""
+struct AGNNConv{A<:AbstractVector} <: GNNLayer
+    β::A
+end
+
+@functor AGNNConv
+
+function AGNNConv(init_beta = 1f0)
+    AGNNConv([init_beta])
+end
+
+function (l::AGNNConv)(g::GNNGraph, x::AbstractMatrix)
+    check_num_nodes(g, x)
+    g = add_self_loops(g)
+
+    xn = x ./ sqrt.(sum(x.^2, dims=1))
+    cos_dist = apply_edges(xi_dot_xj, g, xi=xn, xj=xn)
+    α = softmax_edge_neighbors(g, l.β .* cos_dist)
+
+    x = propagate(g, +; xj=x, e=α) do xi, xj, α
+        α .* xj
+    end
+
+    return x
+end
+
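A minimal usage sketch of the new layer. The graph constructor, graph size, and feature dimension below are illustrative assumptions, not part of this commit:

```julia
using GraphNeuralNetworks

# Illustrative setup: 10 nodes, 30 directed edges, 3 features per node.
# rand_graph is assumed available from GraphNeuralNetworks.jl as a convenient
# way to build a small GNNGraph for testing.
g = rand_graph(10, 30)
x = rand(Float32, 3, 10)      # one column of features per node

l = AGNNConv(1f0)             # init_beta = 1f0; β is the only trainable parameter
y = l(g, x)                   # attention-weighted mix of self and neighbor features
@assert size(y) == size(x)    # the layer preserves the feature dimension
```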