
Commit e5e919c

Merge pull request #68 from CarloLucibello/cl/agnn
add AGNNConv
2 parents 651f761 + c2a1642 commit e5e919c

8 files changed: +95 −25 lines changed

docs/src/api/messagepassing.md

Lines changed: 3 additions & 1 deletion
@@ -21,5 +21,7 @@ propagate
 ## Built-in message functions
 
 ```@docs
-copyxj
+copy_xi
+copy_xj
+xi_dot_xj
 ```
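As a quick illustration of what the three documented functions compute (not part of this diff; the toy graph and feature values below are illustrative assumptions):

```julia
using GraphNeuralNetworks

# A tiny directed graph with edges 1→3 and 2→3 (illustrative).
g = GNNGraph([1, 2], [3, 3])
x = Float32[1 2 3;
            4 5 6]   # 2 features for each of the 3 nodes

# apply_edges evaluates a message function on every edge,
# with xi the target-node features and xj the source-node features.
apply_edges(copy_xj, g, xi = x, xj = x)     # source features, one column per edge
apply_edges(copy_xi, g, xi = x, xj = x)     # target features, one column per edge
apply_edges(xi_dot_xj, g, xi = x, xj = x)   # 1 × num_edges matrix of dot products
```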

docs/src/messagepassing.md

Lines changed: 1 addition & 1 deletion
@@ -88,4 +88,4 @@ See the [`GATConv`](@ref) implementation [here](https://github.com/CarloLucibell
 ## Built-in message functions
 
 In order to exploit optimized specializations of the [`propagate`](@ref), it is recommended
-to use built-in message functions such as [`copyxj`](@ref) whenever possible.
+to use built-in message functions such as [`copy_xj`](@ref) whenever possible.
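A minimal sketch of that recommendation in practice (the random graph and feature sizes are assumptions, not part of the diff): both calls compute the same neighborhood sum, but only the first one dispatches to the optimized specialization of `propagate` defined in `src/msgpass.jl`.

```julia
using GraphNeuralNetworks

g = rand_graph(10, 30)                  # illustrative random graph
x = rand(Float32, 4, g.num_nodes)       # 4 features per node

y_builtin = propagate(copy_xj, g, +, xj = x)            # hits the fast specialization
y_custom  = propagate((xi, xj, e) -> xj, g, +, xj = x)  # generic gather/scatter path

y_builtin ≈ y_custom   # true
```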

src/GraphNeuralNetworks.jl

Lines changed: 2 additions & 1 deletion
@@ -28,14 +28,15 @@ export
 
   # msgpass
   apply_edges, propagate,
-  copyxj,
+  copy_xj, copy_xi, xi_dot_xj,
 
   # layers/basic
   GNNLayer,
   GNNChain,
   WithGraph,
 
   # layers/conv
+  AGNNConv,
   CGConv,
   ChebConv,
   EdgeConv,

src/deprecations.jl

Lines changed: 5 additions & 6 deletions
@@ -1,10 +1,5 @@
-# Deprecated in v0.1
+## Deprecated in v0.2
 
-@deprecate GINConv(nn; eps=0, aggr=+) GINConv(nn, eps; aggr)
-
-
-# Deprecated in v0.2
-# TODO check if argument order is exact
 function compute_message end
 function update_node end
 function update_edge end
@@ -29,3 +24,7 @@ function propagate(l::GNNLayer, g::GNNGraph, aggr, x, e=nothing)
     e = update_edge(l, e, m)
     return x, e
 end
+
+## Deprecated in v0.3
+
+@deprecate copyxj(xi, xj, e) copy_xj(xi, xj, e)

src/layers/conv.jl

Lines changed: 59 additions & 6 deletions
@@ -50,7 +50,7 @@ function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix{T}) where T
     # @assert all(>(0), degree(g, T, dir=:in))
     c = 1 ./ sqrt.(degree(g, T, dir=:in))
     x = x .* c'
-    x = propagate(copyxj, g, +, xj=x)
+    x = propagate(copy_xj, g, +, xj=x)
     x = x .* c'
     if Dout >= Din
         x = l.weight * x
@@ -179,7 +179,7 @@ end
 
 function (l::GraphConv)(g::GNNGraph, x::AbstractMatrix)
     check_num_nodes(g, x)
-    m = propagate(copyxj, g, l.aggr, xj=x)
+    m = propagate(copy_xj, g, l.aggr, xj=x)
     x = l.σ.(l.weight1 * x .+ l.weight2 * m .+ l.bias)
     return x
 end
@@ -206,7 +206,7 @@ Graph attentional layer from the paper [Graph Attention Networks](https://arxiv.
 
 Implements the operation
 ```math
-\mathbf{x}_i' = \sum_{j \in N(i)} \alpha_{ij} W \mathbf{x}_j
+\mathbf{x}_i' = \sum_{j \in N(i) \cup \{i\}} \alpha_{ij} W \mathbf{x}_j
 ```
 where the attention coefficients ``\alpha_{ij}`` are given by
 ```math
@@ -338,7 +338,7 @@ function (l::GatedGraphConv)(g::GNNGraph, H::AbstractMatrix{S}) where {S<:Real}
     end
     for i = 1:l.num_layers
        M = view(l.weight, :, :, i) * H
-       M = propagate(copyxj, g, l.aggr; xj=M)
+       M = propagate(copy_xj, g, l.aggr; xj=M)
        H, _ = l.gru(H, M)
     end
     H
@@ -420,7 +420,7 @@ GINConv(nn, ϵ; aggr=+) = GINConv(nn, ϵ, aggr)
 
 function (l::GINConv)(g::GNNGraph, x::AbstractMatrix)
     check_num_nodes(g, x)
-    m = propagate(copyxj, g, l.aggr, xj=x)
+    m = propagate(copy_xj, g, l.aggr, xj=x)
     l.nn((1 + ofeltype(x, l.ϵ)) * x + m)
 end
 
@@ -542,7 +542,7 @@ end
 
 function (l::SAGEConv)(g::GNNGraph, x::AbstractMatrix)
     check_num_nodes(g, x)
-    m = propagate(copyxj, g, l.aggr, xj=x)
+    m = propagate(copy_xj, g, l.aggr, xj=x)
     x = l.σ.(l.weight * vcat(x, m) .+ l.bias)
     return x
 end
@@ -711,3 +711,56 @@ function Base.show(io::IO, l::CGConv)
     print(io, ", residual=$(l.residual)")
     print(io, ")")
 end
+
+
+@doc raw"""
+    AGNNConv(init_beta=1f0)
+
+Attention-based Graph Neural Network layer from the paper [Attention-based
+Graph Neural Network for Semi-Supervised Learning](https://arxiv.org/abs/1803.03735).
+
+The forward pass is given by
+```math
+\mathbf{x}_i' = \sum_{j \in N(i) \cup \{i\}} \alpha_{ij} \mathbf{x}_j
+```
+where the attention coefficients ``\alpha_{ij}`` are given by
+```math
+\alpha_{ij} = \frac{e^{\beta \cos(\mathbf{x}_i, \mathbf{x}_j)}}
+                   {\sum_{j'} e^{\beta \cos(\mathbf{x}_i, \mathbf{x}_{j'})}}
+```
+with the cosine distance defined by
+```math
+\cos(\mathbf{x}_i, \mathbf{x}_j) =
+  \frac{\mathbf{x}_i \cdot \mathbf{x}_j}{\lVert\mathbf{x}_i\rVert \, \lVert\mathbf{x}_j\rVert}
+```
+and ``\beta`` a trainable parameter.
+
+# Arguments
+
+- `init_beta`: The initial value of ``\beta``.
+"""
+struct AGNNConv{A<:AbstractVector} <: GNNLayer
+    β::A
+end
+
+@functor AGNNConv
+
+function AGNNConv(init_beta = 1f0)
+    AGNNConv([init_beta])
+end
+
+function (l::AGNNConv)(g::GNNGraph, x::AbstractMatrix)
+    check_num_nodes(g, x)
+    g = add_self_loops(g)
+
+    xn = x ./ sqrt.(sum(x.^2, dims=1))
+    cos_dist = apply_edges(xi_dot_xj, g, xi=xn, xj=xn)
+    α = softmax_edge_neighbors(g, l.β .* cos_dist)
+
+    x = propagate(g, +; xj=x, e=α) do xi, xj, α
+        α .* xj
+    end
+
+    return x
+end
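For orientation, a minimal usage sketch of the new layer (not part of the diff; the graph construction and feature sizes are illustrative assumptions):

```julia
using GraphNeuralNetworks

g = rand_graph(6, 20)                  # illustrative random graph
x = rand(Float32, 8, g.num_nodes)      # 8 input features per node

l = AGNNConv()                         # the only trainable parameter is β, initialized to [1f0]
y = l(g, x)

size(y) == size(x)                     # true: the layer mixes node features but keeps their dimension
```

Since the layer has no weight matrix, the output feature dimension always equals the input dimension, which is why the new test below checks `outsize=(in_channel, g.num_nodes)`.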

src/msgpass.jl

Lines changed: 15 additions & 9 deletions
@@ -139,26 +139,32 @@ _scatter(aggr, m::AbstractArray, t) = NNlib.scatter(aggr, m, t)
 
 ### SPECIALIZATIONS OF PROPAGATE ###
 """
-    copyxj(xi, xj, e) = xj
+    copy_xj(xi, xj, e) = xj
 """
-copyxj(xi, xj, e) = xj
+copy_xj(xi, xj, e) = xj
 
-# copyxi(xi, xj, e) = xi
-# ximulxj(xi, xj, e) = xi .* xj
-# xiaddxj(xi, xj, e) = xi .+ xj
+"""
+    copy_xi(xi, xj, e) = xi
+"""
+copy_xi(xi, xj, e) = xi
+
+"""
+    xi_dot_xj(xi, xj, e) = sum(xi .* xj, dims=1)
+"""
+xi_dot_xj(xi, xj, e) = sum(xi .* xj, dims=1)
 
 
-function propagate(::typeof(copyxj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, e)
+function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, e)
     A = adjacency_matrix(g)
     return xj * A
 end
 
 ## avoid the fast path on gpu until we have better cuda support
-function propagate(::typeof(copyxj), g::GNNGraph{<:Union{COO_T,SPARSE_T}}, ::typeof(+), xi, xj::AnyCuMatrix, e)
-    propagate((xi,xj,e)->copyxj(xi,xj,e), g, +, xi, xj, e)
+function propagate(::typeof(copy_xj), g::GNNGraph{<:Union{COO_T,SPARSE_T}}, ::typeof(+), xi, xj::AnyCuMatrix, e)
+    propagate((xi,xj,e)->copy_xj(xi,xj,e), g, +, xi, xj, e)
 end
 
-# function propagate(::typeof(copyxj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e)
+# function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e)
 #     A = adjacency_matrix(g)
 #     D = compute_degree(A)
 #     return xj * A * D
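The `copy_xj` / `+` specialization above works because summing source-node features over the incoming edges of each node is a single multiplication with the adjacency matrix. A rough check of that identity (graph and features are illustrative assumptions; this follows the CPU path kept in this diff):

```julia
using GraphNeuralNetworks

g = rand_graph(5, 14)                  # illustrative random graph
x = rand(Float32, 3, g.num_nodes)      # features stored as (num_features, num_nodes)

A = adjacency_matrix(g)                # A[src, dst] = 1 for every edge src → dst
x * A ≈ propagate(copy_xj, g, +, xj = x)   # true
```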

test/deprecations.jl

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@
 end
 
 function new_forward(l, g, x)
-    x = propagate(copyxj, g, +, xj=x)
+    x = propagate(copy_xj, g, +, xj=x)
     return l.σ.(l.weight * x .+ l.bias)
 end
 

test/layers/conv.jl

Lines changed: 9 additions & 0 deletions
@@ -158,4 +158,13 @@
             test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
         end
     end
+
+
+    @testset "AGNNConv" begin
+        l = AGNNConv()
+        @test l.β == [1f0]
+        for g in test_graphs
+            test_layer(l, g, rtol=1e-5, outsize=(in_channel, g.num_nodes))
+        end
+    end
 end
