diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl index d8970095c..b24ed4118 100644 --- a/GNNLux/src/GNNLux.jl +++ b/GNNLux/src/GNNLux.jl @@ -1,6 +1,7 @@ module GNNLux using ConcreteStructs: @concrete using NNlib: NNlib, sigmoid, relu, swish +using Statistics: mean using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer, parameterlength, statelength, outputsize, initialparameters, initialstates, parameterlength, statelength using Lux: Lux, Chain, Dense, GRUCell, @@ -30,7 +31,7 @@ export AGNNConv, GINConv, # GMMConv, GraphConv, - # MEGNetConv, + MEGNetConv, # NNConv, # ResGatedGraphConv, # SAGEConv, diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index 83c3efddc..30564ae48 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -628,3 +628,45 @@ function Base.show(io::IO, l::GINConv) print(io, ", $(l.ϵ)") print(io, ")") end + +@concrete struct MEGNetConv{TE, TV, A} <: GNNContainerLayer{(:ϕe, :ϕv)} + in_dims::Int + out_dims::Int + ϕe::TE + ϕv::TV + aggr::A +end + +function MEGNetConv(in_dims::Int, out_dims::Int, ϕe::TE, ϕv::TV; aggr::A = mean) where {TE, TV, A} + return MEGNetConv{TE, TV, A}(in_dims, out_dims, ϕe, ϕv, aggr) +end + +function MEGNetConv(ch::Pair{Int, Int}; aggr = mean) + nin, nout = ch + ϕe = Chain(Dense(3nin, nout, relu), + Dense(nout, nout)) + + ϕv = Chain(Dense(nin + nout, nout, relu), + Dense(nout, nout)) + + return MEGNetConv(nin, nout, ϕe, ϕv, aggr=aggr) +end + +function (l::MEGNetConv)(g, x, e, ps, st) + ϕe = StatefulLuxLayer{true}(l.ϕe, ps.ϕe, _getstate(st, :ϕe)) + ϕv = StatefulLuxLayer{true}(l.ϕv, ps.ϕv, _getstate(st, :ϕv)) + m = (; ϕe, ϕv, aggr=l.aggr) + return GNNlib.megnet_conv(m, g, x, e), st +end + + +LuxCore.outputsize(l::MEGNetConv) = (l.out_dims,) + +(l::MEGNetConv)(g, x, ps, st) = l(g, x, nothing, ps, st) + +function Base.show(io::IO, l::MEGNetConv) + nin = l.in_dims + nout = l.out_dims + print(io, "MEGNetConv(", nin, " => ", nout) + print(io, ")") +end \ No newline at end of file diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index 9f010f39e..877e6e90b 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -93,4 +93,17 @@ l = GINConv(nn, 0.5) test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true) end + + @testset "MEGNetConv" begin + l = MEGNetConv(in_dims => out_dims) + + ps = LuxCore.initialparameters(rng, l) + st = LuxCore.initialstates(rng, l) + + e = randn(rng, Float32, in_dims, g.num_edges) + (x_new, e_new), st_new = l(g, x, e, ps, st) + + @test size(x_new) == (out_dims, g.num_nodes) + @test size(e_new) == (out_dims, g.num_edges) + end end diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl index 50b5b34aa..0b7dd2499 100644 --- a/GNNlib/src/layers/conv.jl +++ b/GNNlib/src/layers/conv.jl @@ -721,4 +721,4 @@ function d_conv(l, g::GNNGraph, x::AbstractMatrix) T1_out = T2_out end return h .+ l.bias -end \ No newline at end of file +end