Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion GNNLux/src/GNNLux.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module GNNLux
using ConcreteStructs: @concrete
using NNlib: NNlib, sigmoid, relu, swish
using Statistics: mean
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer, parameterlength, statelength, outputsize,
initialparameters, initialstates, parameterlength, statelength
using Lux: Lux, Chain, Dense, GRUCell,
Expand Down Expand Up @@ -30,7 +31,7 @@ export AGNNConv,
GINConv,
# GMMConv,
GraphConv,
# MEGNetConv,
MEGNetConv,
# NNConv,
# ResGatedGraphConv,
# SAGEConv,
Expand Down
42 changes: 42 additions & 0 deletions GNNLux/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -628,3 +628,45 @@ function Base.show(io::IO, l::GINConv)
print(io, ", $(l.ϵ)")
print(io, ")")
end

@concrete struct MEGNetConv{TE, TV, A} <: GNNContainerLayer{(:ϕe, :ϕv)}
in_dims::Int
out_dims::Int
ϕe::TE
ϕv::TV
aggr::A
end

function MEGNetConv(in_dims::Int, out_dims::Int, ϕe::TE, ϕv::TV; aggr::A = mean) where {TE, TV, A}
return MEGNetConv{TE, TV, A}(in_dims, out_dims, ϕe, ϕv, aggr)
end

function MEGNetConv(ch::Pair{Int, Int}; aggr = mean)
nin, nout = ch
ϕe = Chain(Dense(3nin, nout, relu),
Dense(nout, nout))

ϕv = Chain(Dense(nin + nout, nout, relu),
Dense(nout, nout))

return MEGNetConv(nin, nout, ϕe, ϕv, aggr=aggr)
end

function (l::MEGNetConv)(g, x, e, ps, st)
ϕe = StatefulLuxLayer{true}(l.ϕe, ps.ϕe, _getstate(st, :ϕe))
ϕv = StatefulLuxLayer{true}(l.ϕv, ps.ϕv, _getstate(st, :ϕv))
m = (; ϕe, ϕv, aggr=l.aggr)
return GNNlib.megnet_conv(m, g, x, e), st
end


LuxCore.outputsize(l::MEGNetConv) = (l.out_dims,)

(l::MEGNetConv)(g, x, ps, st) = l(g, x, nothing, ps, st)

function Base.show(io::IO, l::MEGNetConv)
nin = l.in_dims
nout = l.out_dims
print(io, "MEGNetConv(", nin, " => ", nout)
print(io, ")")
end
13 changes: 13 additions & 0 deletions GNNLux/test/layers/conv_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,17 @@
l = GINConv(nn, 0.5)
test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true)
end

@testset "MEGNetConv" begin
l = MEGNetConv(in_dims => out_dims)

ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)

e = randn(rng, Float32, in_dims, g.num_edges)
(x_new, e_new), st_new = l(g, x, e, ps, st)

@test size(x_new) == (out_dims, g.num_nodes)
@test size(e_new) == (out_dims, g.num_edges)
end
end
2 changes: 1 addition & 1 deletion GNNlib/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -721,4 +721,4 @@ function d_conv(l, g::GNNGraph, x::AbstractMatrix)
T1_out = T2_out
end
return h .+ l.bias
end
end