Commit 4636ca3

Merge pull request #83 from CarloLucibello/cl/meg: implement MEGNetConv

2 parents 749e038 + c653600

File tree

6 files changed: +170 -31 lines

docs/src/api/messagepassing.md (1 addition, 0 deletions)

@@ -15,6 +15,7 @@ Pages = ["messagepassing.md"]
 
 ```@docs
 apply_edges
+aggregate_neighbors
 propagate
 ```
 
src/GraphNeuralNetworks.jl (14 additions, 5 deletions)

@@ -21,14 +21,21 @@ using .GNNGraphs: COO_T, ADJMAT_T, SPARSE_T,
 
 export
   # utils
-  reduce_nodes, reduce_edges,
-  softmax_nodes, softmax_edges,
-  broadcast_nodes, broadcast_edges,
+  reduce_nodes,
+  reduce_edges,
+  softmax_nodes,
+  softmax_edges,
+  broadcast_nodes,
+  broadcast_edges,
   softmax_edge_neighbors,
 
   # msgpass
-  apply_edges, propagate,
-  copy_xj, copy_xi, xi_dot_xj,
+  apply_edges,
+  aggregate_neighbors,
+  propagate,
+  copy_xj,
+  copy_xi,
+  xi_dot_xj,
 
   # layers/basic
   GNNLayer,
@@ -46,9 +53,11 @@ export
   GCNConv,
   GINConv,
   GraphConv,
+  MEGNetConv,
   NNConv,
   ResGatedGraphConv,
   SAGEConv,
+
 
   # layers/pool
   GlobalPool,

src/layers/conv.jl (70 additions, 0 deletions)

@@ -823,3 +823,73 @@ function (l::AGNNConv)(g::GNNGraph, x::AbstractMatrix)
     return x
 end
 
+@doc raw"""
+    MEGNetConv(ϕe, ϕv; aggr=mean)
+    MEGNetConv(in => out; aggr=mean)
+
+Convolution from the [Graph Networks as a Universal Machine Learning Framework for Molecules and Crystals](https://arxiv.org/pdf/1812.05055.pdf)
+paper. In the forward pass it takes node features `x` and edge features `e` as inputs and returns
+updated features `x̄, ē` according to
+
+```math
+ē = ϕe(vcat(xi, xj, e))
+x̄ = ϕv(vcat(x, \square_{j\in \mathcal{N}(i)} ē_{j\to i}))
+```
+`aggr` defines the aggregation to be performed.
+
+If the neural networks `ϕe` and `ϕv` are not provided, they will be constructed from
+the `in` and `out` arguments instead, as multi-layer perceptrons with one hidden layer and `relu`
+activations.
+
+# Examples
+
+```julia
+g = rand_graph(10, 30)
+x = randn(3, 10)
+e = randn(3, 30)
+m = MEGNetConv(3 => 3)
+x̄, ē = m(g, x, e)
+```
+"""
+struct MEGNetConv <: GNNLayer
+    ϕe
+    ϕv
+    aggr
+end
+
+@functor MEGNetConv
+
+MEGNetConv(ϕe, ϕv; aggr=mean) = MEGNetConv(ϕe, ϕv, aggr)
+
+function MEGNetConv(ch::Pair{Int,Int}; aggr=mean)
+    nin, nout = ch
+    ϕe = Chain(Dense(3nin, nout, relu),
+               Dense(nout, nout))
+
+    ϕv = Chain(Dense(nin + nout, nout, relu),
+               Dense(nout, nout))
+
+    MEGNetConv(ϕe, ϕv; aggr)
+end
+
+function (l::MEGNetConv)(g::GNNGraph)
+    x, e = l(g, node_features(g), edge_features(g))
+    return GNNGraph(g, ndata=x, edata=e)
+end
+
+function (l::MEGNetConv)(g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
+    check_num_nodes(g, x)
+
+    ē = apply_edges(g, xi=x, xj=x, e=e) do xi, xj, e
+        l.ϕe(vcat(xi, xj, e))
+    end
+
+    xᵉ = aggregate_neighbors(g, l.aggr, ē)
+
+    x̄ = l.ϕv(vcat(x, xᵉ))
+
+    return x̄, ē
+end

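As a usage note for the hunk above: the pair constructor `MEGNetConv(in => out)` simply builds the two MLPs shown in `MEGNetConv(ch::Pair{Int,Int})`. Here is a minimal sketch of the equivalent explicit construction; the feature sizes and random data are illustrative assumptions, not part of the commit.

```julia
using Flux, GraphNeuralNetworks
using Statistics: mean

nin, nout = 3, 3                       # illustrative feature sizes
# ϕe consumes vcat(xi, xj, e), i.e. 3nin input rows per edge.
ϕe = Chain(Dense(3nin, nout, relu), Dense(nout, nout))
# ϕv consumes vcat(x, aggregated ē), i.e. nin + nout input rows per node.
ϕv = Chain(Dense(nin + nout, nout, relu), Dense(nout, nout))

m = MEGNetConv(ϕe, ϕv; aggr=mean)      # same layer as MEGNetConv(nin => nout)

g = rand_graph(10, 30)
x̄, ē = m(g, randn(Float32, nin, 10), randn(Float32, nin, 30))
```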
src/msgpass.jl (14 additions, 2 deletions)

@@ -61,7 +61,7 @@ l = GNNConv(10 => 20)
 l(g, x)
 ```
 
-See also [`apply_edges`](@ref).
+See also [`apply_edges`](@ref) and [`aggregate_neighbors`](@ref).
 """
 function propagate end
 
@@ -103,7 +103,7 @@ such tensors.
 a batch of edges. The output of `f` has to be an array (or a named tuple of arrays)
 with the same batch size.
 
-See also [`propagate`](@ref).
+See also [`propagate`](@ref) and [`aggregate_neighbors`](@ref).
 """
 function apply_edges end
 
@@ -125,7 +125,19 @@ _gather(x::Nothing, i) = nothing
 
 
 ## AGGREGATE NEIGHBORS
+@doc raw"""
+    aggregate_neighbors(g::GNNGraph, aggr, m)
+
+Given a graph `g`, edge features `m`, and an aggregation
+operator `aggr` (e.g. `+`, `min`, `max`, `mean`), returns the new node
+features
+```math
+\mathbf{x}_i = \square_{j \in \mathcal{N}(i)} \mathbf{m}_{j\to i}
+```
 
+Neighborhood aggregation is the second step of [`propagate`](@ref),
+where it comes after [`apply_edges`](@ref).
+"""
 function aggregate_neighbors(g::GNNGraph, aggr, m)
     s, t = edge_index(g)
     return _scatter(aggr, m, t)

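Since `aggregate_neighbors` is newly exported and documented above, a small sketch of its semantics may help; the toy graph and feature values are illustrative assumptions, not part of the commit.

```julia
using GraphNeuralNetworks

# Three directed edges: 1→2, 1→3, 2→3.
g = GNNGraph([1, 1, 2], [2, 3, 3])

# Edge features: one row, one column per edge (m[:, k] belongs to edge k).
m = Float32[10 20 30]

# Sum the features of the edges entering each node:
# node 1 has no incoming edge, node 2 receives 10, node 3 receives 20 + 30.
aggregate_neighbors(g, +, m)   # expected: [0.0 10.0 50.0]
```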
test/layers/conv.jl (10 additions, 0 deletions)

@@ -184,4 +184,14 @@
         test_layer(l, g, rtol=1e-5, outsize=(in_channel, g.num_nodes))
     end
 end
+
+@testset "MEGNetConv" begin
+    l = MEGNetConv(in_channel => out_channel, aggr=+)
+    for g in test_graphs
+        g = GNNGraph(g, edata=rand(T, in_channel, g.num_edges))
+        test_layer(l, g, rtol=1e-5,
+                   outtype=:node_edge,
+                   outsize=((out_channel, g.num_nodes), (out_channel, g.num_edges)))
+    end
+end
 end

test/test_utils.jl (61 additions, 24 deletions)

@@ -30,66 +30,103 @@ function test_layer(l, g::GNNGraph; atol = 1e-6, rtol = 1e-5,
 
     x = node_features(g)
     e = edge_features(g)
+    use_edge_feat = !isnothing(e)
 
     x64, e64, l64, g64 = to64.([x, e, l, g]) # needed for accurate FiniteDifferences' grad
    xgpu, egpu, lgpu, ggpu = gpu.([x, e, l, g])
 
     f(l, g::GNNGraph) = l(g)
-    f(l, g::GNNGraph, x::AbstractArray{Float32}) = isnothing(e) ? l(g, x) : l(g, x, e)
-    f(l, g::GNNGraph, x::AbstractArray{Float64}) = isnothing(e64) ? l(g, x) : l(g, x, e64)
-    f(l, g::GNNGraph, x::CuArray) = isnothing(e64) ? l(g, x) : l(g, x, egpu)
+    f(l, g::GNNGraph, x, e) = use_edge_feat ? l(g, x, e) : l(g, x)
 
     loss(l, g::GNNGraph) = if outtype == :node
             sum(node_features(f(l, g)))
         elseif outtype == :edge
             sum(edge_features(f(l, g)))
         elseif outtype == :graph
             sum(graph_features(f(l, g)))
+        elseif outtype == :node_edge
+            gnew = f(l, g)
+            sum(node_features(gnew)) + sum(edge_features(gnew))
         end
 
-    loss(l, g::GNNGraph, x) = sum(f(l, g, x))
-    loss(l, g::GNNGraph, x, e) = sum(l(g, x, e))
+    function loss(l, g::GNNGraph, x, e)
+        y = f(l, g, x, e)
+        if outtype == :node_edge
+            return sum(y[1]) + sum(y[2])
+        else
+            return sum(y)
+        end
+    end
 
 
     # TEST OUTPUT
-    y = f(l, g, x)
-    @test eltype(y) == eltype(x)
-    @test all(isfinite, y)
-    if !isnothing(outsize)
-        @test size(y) == outsize
+    y = f(l, g, x, e)
+    if outtype == :node_edge
+        @assert y isa Tuple
+        @test eltype(y[1]) == eltype(x)
+        @test eltype(y[2]) == eltype(e)
+        @test all(isfinite, y[1])
+        @test all(isfinite, y[2])
+        if !isnothing(outsize)
+            @test size(y[1]) == outsize[1]
+            @test size(y[2]) == outsize[2]
+        end
+    else
+        @test eltype(y) == eltype(x)
+        @test all(isfinite, y)
+        if !isnothing(outsize)
+            @test size(y) == outsize
+        end
     end
 
     # test same output on different graph formats
     gcoo = GNNGraph(g, graph_type=:coo)
-    ycoo = f(l, gcoo, x)
-    @test ycoo ≈ y
-
+    ycoo = f(l, gcoo, x, e)
+    if outtype == :node_edge
+        @test ycoo[1] ≈ y[1]
+        @test ycoo[2] ≈ y[2]
+    else
+        @test ycoo ≈ y
+    end
+
     g′ = f(l, g)
     if outtype == :node
         @test g′.ndata.x ≈ y
     elseif outtype == :edge
         @test g′.edata.e ≈ y
     elseif outtype == :graph
         @test g′.gdata.u ≈ y
+    elseif outtype == :node_edge
+        @test g′.ndata.x ≈ y[1]
+        @test g′.edata.e ≈ y[2]
     else
         @error "wrong outtype $outtype"
     end
     if test_gpu
-        ygpu = f(lgpu, ggpu, xgpu)
-        @test ygpu isa CuArray
-        @test eltype(ygpu) == eltype(xgpu)
-        @test Array(ygpu) ≈ y
+        ygpu = f(lgpu, ggpu, xgpu, egpu)
+        if outtype == :node_edge
+            @test ygpu[1] isa CuArray
+            @test eltype(ygpu[1]) == eltype(xgpu)
+            @test Array(ygpu[1]) ≈ y[1]
+            @test ygpu[2] isa CuArray
+            @test eltype(ygpu[2]) == eltype(xgpu)
+            @test Array(ygpu[2]) ≈ y[2]
+        else
+            @test ygpu isa CuArray
+            @test eltype(ygpu) == eltype(xgpu)
+            @test Array(ygpu) ≈ y
+        end
    end
 
 
     # TEST x INPUT GRADIENT
-    x̄ = gradient(x -> loss(l, g, x), x)[1]
-    x̄_fd = FiniteDifferences.grad(fdm, x64 -> loss(l64, g64, x64), x64)[1]
+    x̄ = gradient(x -> loss(l, g, x, e), x)[1]
+    x̄_fd = FiniteDifferences.grad(fdm, x64 -> loss(l64, g64, x64, e64), x64)[1]
     @test eltype(x̄) == eltype(x)
     @test x̄ ≈ x̄_fd atol=atol rtol=rtol
 
     if test_gpu
-        x̄gpu = gradient(xgpu -> loss(lgpu, ggpu, xgpu), xgpu)[1]
+        x̄gpu = gradient(xgpu -> loss(lgpu, ggpu, xgpu, egpu), xgpu)[1]
         @test x̄gpu isa CuArray
         @test eltype(x̄gpu) == eltype(x)
         @test Array(x̄gpu) ≈ x̄ atol=atol rtol=rtol
@@ -112,13 +149,13 @@ function test_layer(l, g::GNNGraph; atol = 1e-6, rtol = 1e-5,
     end
 
 
-    # TEST LAYER GRADIENT - l(g, x)
-    l̄ = gradient(l -> loss(l, g, x), l)[1]
-    l̄_fd = FiniteDifferences.grad(fdm, l64 -> loss(l64, g64, x64), l64)[1]
+    # TEST LAYER GRADIENT - l(g, x, e)
+    l̄ = gradient(l -> loss(l, g, x, e), l)[1]
+    l̄_fd = FiniteDifferences.grad(fdm, l64 -> loss(l64, g64, x64, e64), l64)[1]
     test_approx_structs(l, l̄, l̄_fd; atol, rtol, broken_grad_fields, exclude_grad_fields, verbose)
 
     if test_gpu
-        l̄gpu = gradient(lgpu -> loss(lgpu, ggpu, xgpu), lgpu)[1]
+        l̄gpu = gradient(lgpu -> loss(lgpu, ggpu, xgpu, egpu), lgpu)[1]
         test_approx_structs(lgpu, l̄gpu, l̄; atol, rtol, broken_grad_fields, exclude_grad_fields, verbose)
     end
