Skip to content

Commit ed78e88

Browse files
authored
[GNNLux] Adding MegNetConv Layer (#480)
* megnet WIP * fix * fix * fix output * wip * temporary changes to run tests * testing * test * test * mean * mean * fix * fix * fix * added edge check * test * fix * Update basic_tests.jl * Update conv_tests.jl: Fixing tests * Update conv.jl: Back to old commit * Update conv_tests.jl: Fix tests * Update conv_tests.jl * Update conv.jl
1 parent 9e9ba9d commit ed78e88

File tree

4 files changed

+58
-2
lines changed

4 files changed

+58
-2
lines changed

GNNLux/src/GNNLux.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module GNNLux
22
using ConcreteStructs: @concrete
33
using NNlib: NNlib, sigmoid, relu, swish
4+
using Statistics: mean
45
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer, parameterlength, statelength, outputsize,
56
initialparameters, initialstates, parameterlength, statelength
67
using Lux: Lux, Chain, Dense, GRUCell,
@@ -30,7 +31,7 @@ export AGNNConv,
3031
GINConv,
3132
# GMMConv,
3233
GraphConv,
33-
# MEGNetConv,
34+
MEGNetConv,
3435
# NNConv,
3536
# ResGatedGraphConv,
3637
# SAGEConv,

GNNLux/src/layers/conv.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -628,3 +628,45 @@ function Base.show(io::IO, l::GINConv)
628628
print(io, ", $(l.ϵ)")
629629
print(io, ")")
630630
end
631+
632+
@concrete struct MEGNetConv{TE, TV, A} <: GNNContainerLayer{(:ϕe, :ϕv)}
633+
in_dims::Int
634+
out_dims::Int
635+
ϕe::TE
636+
ϕv::TV
637+
aggr::A
638+
end
639+
640+
function MEGNetConv(in_dims::Int, out_dims::Int, ϕe::TE, ϕv::TV; aggr::A = mean) where {TE, TV, A}
641+
return MEGNetConv{TE, TV, A}(in_dims, out_dims, ϕe, ϕv, aggr)
642+
end
643+
644+
function MEGNetConv(ch::Pair{Int, Int}; aggr = mean)
645+
nin, nout = ch
646+
ϕe = Chain(Dense(3nin, nout, relu),
647+
Dense(nout, nout))
648+
649+
ϕv = Chain(Dense(nin + nout, nout, relu),
650+
Dense(nout, nout))
651+
652+
return MEGNetConv(nin, nout, ϕe, ϕv, aggr=aggr)
653+
end
654+
655+
function (l::MEGNetConv)(g, x, e, ps, st)
656+
ϕe = StatefulLuxLayer{true}(l.ϕe, ps.ϕe, _getstate(st, :ϕe))
657+
ϕv = StatefulLuxLayer{true}(l.ϕv, ps.ϕv, _getstate(st, :ϕv))
658+
m = (; ϕe, ϕv, aggr=l.aggr)
659+
return GNNlib.megnet_conv(m, g, x, e), st
660+
end
661+
662+
663+
LuxCore.outputsize(l::MEGNetConv) = (l.out_dims,)
664+
665+
(l::MEGNetConv)(g, x, ps, st) = l(g, x, nothing, ps, st)
666+
667+
function Base.show(io::IO, l::MEGNetConv)
668+
nin = l.in_dims
669+
nout = l.out_dims
670+
print(io, "MEGNetConv(", nin, " => ", nout)
671+
print(io, ")")
672+
end

GNNLux/test/layers/conv_tests.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,4 +93,17 @@
9393
l = GINConv(nn, 0.5)
9494
test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true)
9595
end
96+
97+
@testset "MEGNetConv" begin
98+
l = MEGNetConv(in_dims => out_dims)
99+
100+
ps = LuxCore.initialparameters(rng, l)
101+
st = LuxCore.initialstates(rng, l)
102+
103+
e = randn(rng, Float32, in_dims, g.num_edges)
104+
(x_new, e_new), st_new = l(g, x, e, ps, st)
105+
106+
@test size(x_new) == (out_dims, g.num_nodes)
107+
@test size(e_new) == (out_dims, g.num_edges)
108+
end
96109
end

GNNlib/src/layers/conv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -724,4 +724,4 @@ function d_conv(l, g::GNNGraph, x::AbstractMatrix)
724724
T1_out = T2_out
725725
end
726726
return h .+ l.bias
727-
end
727+
end

0 commit comments

Comments
 (0)