Commit d1831e7
[GNNLux] add GMMConv, ResGatedGraphConv (#494)
1 parent 2313a96 commit d1831e7

6 files changed: +141 −13 lines

GNNLux/src/GNNLux.jl

Lines changed: 2 additions & 2 deletions
@@ -30,11 +30,11 @@ export AGNNConv,
     GatedGraphConv,
     GCNConv,
     GINConv,
-    # GMMConv,
+    GMMConv,
     GraphConv,
     MEGNetConv,
     NNConv,
-    # ResGatedGraphConv,
+    ResGatedGraphConv,
     # SAGEConv,
     SGConv
     # TAGConv,

GNNLux/src/layers/conv.jl

Lines changed: 117 additions & 1 deletion
@@ -628,6 +628,68 @@ function Base.show(io::IO, l::GINConv)
     print(io, ")")
 end
 
+@concrete struct GMMConv <: GNNLayer
+    σ
+    ch::Pair{NTuple{2, Int}, Int}
+    K::Int
+    residual::Bool
+    init_weight
+    init_bias
+    use_bias::Bool
+    dense_x
+end
+
+function GMMConv(ch::Pair{NTuple{2, Int}, Int},
+                 σ = identity;
+                 K::Int = 1,
+                 residual = false,
+                 init_weight = glorot_uniform,
+                 init_bias = zeros32,
+                 use_bias = true)
+    dense_x = Dense(ch[1][1] => ch[2] * K, use_bias = false)
+    return GMMConv(σ, ch, K, residual, init_weight, init_bias, use_bias, dense_x)
+end
+
+function LuxCore.initialparameters(rng::AbstractRNG, l::GMMConv)
+    ein = l.ch[1][2]
+    mu = l.init_weight(rng, ein, l.K)
+    sigma_inv = l.init_weight(rng, ein, l.K)
+    ps = (; mu, sigma_inv, dense_x = LuxCore.initialparameters(rng, l.dense_x))
+    if l.use_bias
+        bias = l.init_bias(rng, l.ch[2])
+        ps = (; ps..., bias)
+    end
+    return ps
+end
+
+LuxCore.outputsize(l::GMMConv) = (l.ch[2],)
+
+function LuxCore.parameterlength(l::GMMConv)
+    n = 2 * l.ch[1][2] * l.K
+    n += parameterlength(l.dense_x)
+    if l.use_bias
+        n += l.ch[2]
+    end
+    return n
+end
+
+function (l::GMMConv)(g::GNNGraph, x, e, ps, st)
+    dense_x = StatefulLuxLayer{true}(l.dense_x, ps.dense_x, _getstate(st, :dense_x))
+    m = (; ps.mu, ps.sigma_inv, dense_x, l.σ, l.ch, l.K, l.residual, bias = _getbias(ps))
+    return GNNlib.gmm_conv(m, g, x, e), st
+end
+
+function Base.show(io::IO, l::GMMConv)
+    (nin, ein), out = l.ch
+    print(io, "GMMConv((", nin, ",", ein, ")=>", out)
+    l.σ == identity || print(io, ", σ=", l.σ)
+    print(io, ", K=", l.K)
+    print(io, ", residual=", l.residual)
+    l.use_bias == true || print(io, ", use_bias=false")
+    print(io, ")")
+end
+
 @concrete struct MEGNetConv{TE, TV, A} <: GNNContainerLayer{(:ϕe, :ϕv)}
     in_dims::Int
     out_dims::Int
@@ -712,6 +774,8 @@ function LuxCore.parameterlength(l::NNConv)
     return n
 end
 
+LuxCore.outputsize(l::NNConv) = (l.out_dims,)
+
 LuxCore.statelength(l::NNConv) = statelength(l.nn)
 
 function (l::NNConv)(g, x, e, ps, st)
@@ -723,7 +787,59 @@ function (l::NNConv)(g, x, e, ps, st)
 end
 
 function Base.show(io::IO, l::NNConv)
-    print(io, "NNConv($(l.nn)")
+    print(io, "NNConv($(l.in_dims) => $(l.out_dims), $(l.nn)")
+    l.σ == identity || print(io, ", ", l.σ)
+    l.use_bias || print(io, ", use_bias=false")
+    print(io, ")")
+end
+
+@concrete struct ResGatedGraphConv <: GNNLayer
+    in_dims::Int
+    out_dims::Int
+    σ
+    init_bias
+    init_weight
+    use_bias::Bool
+end
+
+function ResGatedGraphConv(ch::Pair{Int, Int}, σ = identity;
+                           init_weight = glorot_uniform,
+                           init_bias = zeros32,
+                           use_bias::Bool = true)
+    in_dims, out_dims = ch
+    return ResGatedGraphConv(in_dims, out_dims, σ, init_bias, init_weight, use_bias)
+end
+
+function LuxCore.initialparameters(rng::AbstractRNG, l::ResGatedGraphConv)
+    A = l.init_weight(rng, l.out_dims, l.in_dims)
+    B = l.init_weight(rng, l.out_dims, l.in_dims)
+    U = l.init_weight(rng, l.out_dims, l.in_dims)
+    V = l.init_weight(rng, l.out_dims, l.in_dims)
+    if l.use_bias
+        bias = l.init_bias(rng, l.out_dims)
+        return (; A, B, U, V, bias)
+    else
+        return (; A, B, U, V)
+    end
+end
+
+function LuxCore.parameterlength(l::ResGatedGraphConv)
+    n = 4 * l.in_dims * l.out_dims
+    if l.use_bias
+        n += l.out_dims
+    end
+    return n
+end
+
+LuxCore.outputsize(l::ResGatedGraphConv) = (l.out_dims,)
+
+function (l::ResGatedGraphConv)(g, x, ps, st)
+    m = (; ps.A, ps.B, ps.U, ps.V, bias = _getbias(ps), l.σ)
+    return GNNlib.res_gated_graph_conv(m, g, x), st
+end
+
+function Base.show(io::IO, l::ResGatedGraphConv)
+    print(io, "ResGatedGraphConv(", l.in_dims, " => ", l.out_dims)
     l.σ == identity || print(io, ", ", l.σ)
     l.use_bias || print(io, ", use_bias=false")
     print(io, ")")

GNNLux/test/layers/conv_tests.jl

Lines changed: 12 additions & 6 deletions
@@ -120,12 +120,18 @@
         l = NNConv(n_in => n_out, nn, tanh, aggr = +)
         x = randn(Float32, n_in, g2.num_nodes)
         e = randn(Float32, n_in_edge, g2.num_edges)
+        test_lux_layer(rng, l, g2, x; outputsize=(n_out,), e, container=true)
+    end
 
-        ps = LuxCore.initialparameters(rng, l)
-        st = LuxCore.initialstates(rng, l)
+    @testset "GMMConv" begin
+        ein_dims = 4
+        e = randn(rng, Float32, ein_dims, g.num_edges)
+        l = GMMConv((in_dims, ein_dims) => out_dims, tanh; K = 2, residual = false)
+        test_lux_layer(rng, l, g, x; outputsize=(out_dims,), e)
+    end
 
-        y, st′ = l(g2, x, e, ps, st)
-
-        @test size(y) == (n_out, g2.num_nodes)
-    end
+    @testset "ResGatedGraphConv" begin
+        l = ResGatedGraphConv(in_dims => out_dims, tanh)
+        test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
+    end
 end

GNNLux/test/shared_testsetup.jl

Lines changed: 5 additions & 1 deletion
@@ -42,7 +42,11 @@ function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
         @test size(y) == (outputsize..., g.num_nodes)
     end
 
-    loss = (x, ps) -> sum(first(l(g, x, ps, st)))
+    if e !== nothing
+        loss = (x, ps) -> sum(first(l(g, x, e, ps, st)))
+    else
+        loss = (x, ps) -> sum(first(l(g, x, ps, st)))
+    end
 
     test_gradients(loss, x, ps; atol, rtol, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
 end
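
The new branch exists because edge-consuming layers have a five-argument call signature, so the gradient-test loss must thread `e` through. A sketch of the two call paths being distinguished (same `l`, `g`, `x`, `e`, `ps`, `st` as in the helper above):

    y, st′ = l(g, x, ps, st)     # node features only, e.g. ResGatedGraphConv
    y, st′ = l(g, x, e, ps, st)  # node + edge features, e.g. GMMConv, NNConv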

GNNlib/src/layers/conv.jl

Lines changed: 1 addition & 1 deletion
@@ -389,7 +389,7 @@ function gmm_conv(l, g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
     m = propagate(e_mul_xj, g, mean, xj = xj, e = w)
     m = dropdims(mean(m, dims = 2), dims = 2) # (out, num_nodes)
 
-    m = l.σ(m .+ l.bias)
+    m = l.σ.(m .+ l.bias)
 
     if l.residual
         if size(x, 1) == size(m, 1)
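
The added dot is not cosmetic: in Julia, applying a scalar function to a matrix is not elementwise. The bug was masked when σ == identity, which accepts any argument. A standalone illustration in plain Julia with the LinearAlgebra stdlib:

    using LinearAlgebra

    m = randn(Float32, 4, 4)
    tanh.(m)  # elementwise activation — what gmm_conv needs
    tanh(m)   # matrix hyperbolic tangent, a different operation entirely

    # For the non-square (out × num_nodes) matrices produced here, the undotted
    # call does not even run: tanh(randn(Float32, 4, 10)) throws DimensionMismatch.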

src/layers/conv.jl

Lines changed: 4 additions & 2 deletions
@@ -717,7 +717,9 @@ end
 function Base.show(io::IO, l::NNConv)
     out, in = size(l.weight)
     print(io, "NNConv($in => $out")
-    print(io, ", aggr=", l.aggr)
+    print(io, ", ", l.nn)
+    l.σ == identity || print(io, ", ", l.σ)
+    (l.aggr == +) || print(io, "; aggr=", l.aggr)
     print(io, ")")
 end
 
@@ -1136,7 +1138,7 @@ function Base.show(io::IO, l::GMMConv)
     print(io, "GMMConv((", nin, ",", ein, ")=>", out)
     l.σ == identity || print(io, ", σ=", l.σ)
     print(io, ", K=", l.K)
-    l.residual == true || print(io, ", residual=", l.residual)
+    print(io, ", residual=", l.residual)
     print(io, ")")
 end
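
With this change the Flux-side NNConv now prints its inner network and suppresses default settings. A hypothetical REPL interaction (the exact Dense display may vary across Flux versions):

    julia> nn = Dense(2 => 12);  # maps 2 edge features to 3*4 weight entries

    julia> NNConv(3 => 4, nn, tanh)
    NNConv(3 => 4, Dense(2 => 12), tanh)

    julia> NNConv(3 => 4, nn)    # identity σ and default aggr (+) are omitted
    NNConv(3 => 4, Dense(2 => 12))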
