Commit 9d4346b

fix AGNNConv (#328)
1 parent bb962f5 · commit 9d4346b

3 files changed: 27 additions, 9 deletions


src/deprecations.jl

Lines changed: 1 addition & 0 deletions
@@ -1 +1,2 @@
 
+@deprecate AGNNConv(init_beta) AGNNConv(; init_beta)
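A minimal sketch of what this deprecation means for callers, assuming `AGNNConv` is in scope (e.g. via `using GraphNeuralNetworks`; the package name is an assumption based on the identifiers in this commit):

```julia
# Old positional form: still works, but emits a deprecation warning
# and forwards to the keyword constructor.
l_old = AGNNConv(0.5f0)

# New keyword form introduced by this commit.
l_new = AGNNConv(init_beta = 0.5f0)

l_old.β == l_new.β == [0.5f0]   # both construct a layer with the same β
```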

src/layers/conv.jl

Lines changed: 16 additions & 8 deletions
@@ -978,14 +978,14 @@ function Base.show(io::IO, l::CGConv)
 end
 
 @doc raw"""
-    AGNNConv(init_beta=1f0)
+    AGNNConv(; init_beta=1.0f0, trainable=true, add_self_loops=true)
 
 Attention-based Graph Neural Network layer from paper [Attention-based
 Graph Neural Network for Semi-Supervised Learning](https://arxiv.org/abs/1803.03735).
 
 The forward pass is given by
 ```math
-\mathbf{x}_i' = \sum_{j \in {N(i) \cup \{i\}}} \alpha_{ij} W \mathbf{x}_j
+\mathbf{x}_i' = \sum_{j \in N(i)} \alpha_{ij} \mathbf{x}_j
 ```
 where the attention coefficients ``\alpha_{ij}`` are given by
 ```math
@@ -997,32 +997,40 @@ with the cosine distance defined by
 \cos(\mathbf{x}_i, \mathbf{x}_j) =
   \frac{\mathbf{x}_i \cdot \mathbf{x}_j}{\lVert\mathbf{x}_i\rVert \lVert\mathbf{x}_j\rVert}
 ```
-and ``\beta`` a trainable parameter.
+and ``\beta`` a trainable parameter if `trainable=true`.
 
 # Arguments
 
-- `init_beta`: The initial value of ``\beta``.
+- `init_beta`: The initial value of ``\beta``. Default 1.0f0.
+- `trainable`: If true, ``\beta`` is trainable. Default `true`.
+- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `true`.
 """
 struct AGNNConv{A <: AbstractVector} <: GNNLayer
     β::A
+    add_self_loops::Bool
+    trainable::Bool
 end
 
 @functor AGNNConv
 
-function AGNNConv(init_beta = 1.0f0)
-    AGNNConv([init_beta])
+Flux.trainable(l::AGNNConv) = l.trainable ? (; l.β) : (;)
+
+function AGNNConv(; init_beta = 1.0f0, add_self_loops = true, trainable = true)
+    AGNNConv([init_beta], add_self_loops, trainable)
 end
 
 function (l::AGNNConv)(g::GNNGraph, x::AbstractMatrix)
     check_num_nodes(g, x)
-    g = add_self_loops(g)
+    if l.add_self_loops
+        g = add_self_loops(g)
+    end
 
     xn = x ./ sqrt.(sum(x .^ 2, dims = 1))
     cos_dist = apply_edges(xi_dot_xj, g, xi = xn, xj = xn)
     α = softmax_edge_neighbors(g, l.β .* cos_dist)
 
     x = propagate(g, +; xj = x, e = α) do xi, xj, α
-        α .* xj
+        α .* xj
     end
 
     return x
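A short usage sketch of the reworked layer (a hypothetical example, assuming GraphNeuralNetworks.jl and its `rand_graph` helper; the graph size and feature dimension are arbitrary):

```julia
using GraphNeuralNetworks, Flux

g = rand_graph(5, 14)        # toy graph: 5 nodes, 14 edges
x = randn(Float32, 8, 5)     # 8 features per node, one column per node

# Fixed β, no self loops added before message passing.
l = AGNNConv(init_beta = 1.0f0, trainable = false, add_self_loops = false)
y = l(g, x)                  # 8×5 output; attention reweights neighbor features
Flux.trainable(l)            # (;) — β is hidden from the optimiser

# Defaults match the old behavior: trainable β, self loops added.
l = AGNNConv()
Flux.trainable(l)            # (β = Float32[1.0],)
```

Note the design choice: `β` remains a struct field covered by `@functor` (so it still moves with `gpu`/`cpu`), while the `Flux.trainable` overload only decides whether the optimiser sees it.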

test/layers/conv.jl

Lines changed: 10 additions & 1 deletion
@@ -244,8 +244,17 @@ end
 end
 
 @testset "AGNNConv" begin
-    l = AGNNConv()
+    l = AGNNConv(trainable=false, add_self_loops=false)
     @test l.β == [1.0f0]
+    @test l.add_self_loops == false
+    @test l.trainable == false
+    Flux.trainable(l) == (;)
+
+    l = AGNNConv(init_beta=2.0f0)
+    @test l.β == [2.0f0]
+    @test l.add_self_loops == true
+    @test l.trainable == true
+    Flux.trainable(l) == (; β = [1f0])
     for g in test_graphs
         test_layer(l, g, rtol = RTOL_HIGH, outsize = (in_channel, g.num_nodes))
     end
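For reference, the `Flux.trainable` behavior implied by the new constructor (a hypothetical REPL check; the values follow directly from `AGNNConv([init_beta], add_self_loops, trainable)` and the `Flux.trainable` overload above):

```julia
l = AGNNConv(trainable = false, add_self_loops = false)
Flux.trainable(l)    # (;) — no parameters exposed for training

l = AGNNConv(init_beta = 2.0f0)
Flux.trainable(l)    # (β = Float32[2.0],) — β initialised from init_beta
```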
