Skip to content

Commit 3c6079d

Browse files
implement MEGNetConv
1 parent b349bc4 commit 3c6079d

File tree

3 files changed

+55
-0
lines changed

3 files changed

+55
-0
lines changed

src/GraphNeuralNetworks.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,11 @@ export
4646
GCNConv,
4747
GINConv,
4848
GraphConv,
49+
MEGNetConv,
4950
NNConv,
5051
ResGatedGraphConv,
5152
SAGEConv,
53+
5254

5355
# layers/pool
5456
GlobalPool,

src/layers/conv.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -823,3 +823,48 @@ function (l::AGNNConv)(g::GNNGraph, x::AbstractMatrix)
823823
return x
824824
end
825825

826+
@doc raw"""
827+
MEGNetConv(in => out; aggr=mean)
828+
829+
Convolution from [Graph Networks as a Universal Machine Learning Framework for Molecules and Crystals](https://arxiv.org/pdf/1812.05055.pdf)
830+
paper.
831+
"""
832+
struct MEGNetConv <: GNNLayer
833+
ϕe
834+
ϕv
835+
aggr
836+
end
837+
838+
@functor MEGNetConv
839+
840+
function MEGNetConv(ch::Pair{Int,Int}; aggr=mean)
841+
nin, nout = ch
842+
ϕe = Chain(Dense(3nin, nout, relu),
843+
Dense(nout, nout))
844+
845+
ϕv = Chain(Dense(nin + nout, nout, relu),
846+
Dense(nout, nout))
847+
848+
MEGNetConv(ϕe, ϕv, aggr)
849+
end
850+
851+
function (l::MEGNetConv)(g::GNNGraph)
852+
x, e = l(g, node_features(g), edge_features(g))
853+
g = GNNGraph(g, ndata=x, edata=e)
854+
end
855+
)
856+
function (l::MEGNetConv)(g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
857+
check_num_nodes(g, x)
858+
859+
= apply_edges(g, xi=x, xj=x, e=e) do xi, xj, e
860+
l.ϕe(vcat(xi, xj, e))
861+
end
862+
863+
xᵉ = aggregate_neighbors(g, l.aggr, ē)
864+
865+
= l.ϕv(vcat(x, xᵉ))
866+
867+
return x̄, ē
868+
end
869+
870+

test/layers/conv.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,4 +184,12 @@
184184
test_layer(l, g, rtol=1e-5, outsize=(in_channel, g.num_nodes))
185185
end
186186
end
187+
188+
@testset "MEGNetConv" begin
189+
l = MEGNetConv(in_channel, edim => out_channel, tanh, aggr=+)
190+
for g in test_graphs
191+
g = GNNGraph(g, edata=rand(T, in_channel, g.num_edges))
192+
test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
193+
end
194+
end
187195
end

0 commit comments

Comments
 (0)