Skip to content

Commit c896eda

Browse files
authored
[GNNLux] Adding NNConv Layer (#491)
* nnlux * Update conv_tests.jl: test * fix * Update conv.jl: show * Update shared_testsetup.jl: changed to e
1 parent 5715b26 commit c896eda

File tree

4 files changed

+89
-4
lines changed

4 files changed

+89
-4
lines changed

GNNLux/src/GNNLux.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ export AGNNConv,
3333
# GMMConv,
3434
GraphConv,
3535
MEGNetConv,
36-
# NNConv,
36+
NNConv,
3737
# ResGatedGraphConv,
3838
# SAGEConv,
3939
SGConv

GNNLux/src/layers/conv.jl

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,3 +669,62 @@ function Base.show(io::IO, l::MEGNetConv)
669669
print(io, "MEGNetConv(", nin, " => ", nout)
670670
print(io, ")")
671671
end
672+
673+
@concrete struct NNConv <: GNNContainerLayer{(:nn,)}
674+
nn <: AbstractLuxLayer
675+
aggr
676+
in_dims::Int
677+
out_dims::Int
678+
use_bias::Bool
679+
init_weight
680+
init_bias
681+
σ
682+
end
683+
684+
function NNConv(ch::Pair{Int, Int}, nn, σ = identity;
685+
aggr = +,
686+
init_bias = zeros32,
687+
use_bias::Bool = true,
688+
init_weight = glorot_uniform)
689+
in_dims, out_dims = ch
690+
σ = NNlib.fast_act(σ)
691+
return NNConv(nn, aggr, in_dims, out_dims, use_bias, init_weight, init_bias, σ)
692+
end
693+
694+
function LuxCore.initialparameters(rng::AbstractRNG, l::NNConv)
695+
weight = l.init_weight(rng, l.out_dims, l.in_dims)
696+
ps = (; nn = LuxCore.initialparameters(rng, l.nn), weight)
697+
if l.use_bias
698+
ps = (; ps..., bias = l.init_bias(rng, l.out_dims))
699+
end
700+
return ps
701+
end
702+
703+
function LuxCore.initialstates(rng::AbstractRNG, l::NNConv)
704+
return (; nn = LuxCore.initialstates(rng, l.nn))
705+
end
706+
707+
function LuxCore.parameterlength(l::NNConv)
708+
n = parameterlength(l.nn) + l.in_dims * l.out_dims
709+
if l.use_bias
710+
n += l.out_dims
711+
end
712+
return n
713+
end
714+
715+
LuxCore.statelength(l::NNConv) = statelength(l.nn)
716+
717+
function (l::NNConv)(g, x, e, ps, st)
718+
nn = StatefulLuxLayer{true}(l.nn, ps.nn, st.nn)
719+
m = (; nn, l.aggr, ps.weight, bias = _getbias(ps), l.σ)
720+
y = GNNlib.nn_conv(m, g, x, e)
721+
stnew = _getstate(nn)
722+
return y, stnew
723+
end
724+
725+
function Base.show(io::IO, l::NNConv)
726+
print(io, "NNConv($(l.nn)")
727+
l.σ == identity || print(io, ", ", l.σ)
728+
l.use_bias || print(io, ", use_bias=false")
729+
print(io, ")")
730+
end

GNNLux/test/layers/conv_tests.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,4 +106,26 @@
106106
@test size(x_new) == (out_dims, g.num_nodes)
107107
@test size(e_new) == (out_dims, g.num_edges)
108108
end
109+
110+
@testset "NNConv" begin
111+
n_in = 3
112+
n_in_edge = 10
113+
n_out = 5
114+
115+
s = [1,1,2,3]
116+
t = [2,3,1,1]
117+
g2 = GNNGraph(s, t)
118+
119+
nn = Dense(n_in_edge => n_out * n_in)
120+
l = NNConv(n_in => n_out, nn, tanh, aggr = +)
121+
x = randn(Float32, n_in, g2.num_nodes)
122+
e = randn(Float32, n_in_edge, g2.num_edges)
123+
124+
ps = LuxCore.initialparameters(rng, l)
125+
st = LuxCore.initialstates(rng, l)
126+
127+
y, st′ = l(g2, x, e, ps, st)
128+
129+
@test size(y) == (n_out, g2.num_nodes)
130+
end
109131
end

GNNLux/test/shared_testsetup.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ export test_lux_layer
1414

1515
function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
1616
outputsize=nothing, sizey=nothing, container=false,
17-
atol=1.0f-2, rtol=1.0f-2)
17+
atol=1.0f-2, rtol=1.0f-2, e=nothing)
1818

1919
if container
2020
@test l isa GNNContainerLayer
@@ -27,7 +27,11 @@ function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
2727
@test LuxCore.parameterlength(l) == LuxCore.parameterlength(ps)
2828
@test LuxCore.statelength(l) == LuxCore.statelength(st)
2929

30-
y, st′ = l(g, x, ps, st)
30+
if e !== nothing
31+
y, st′ = l(g, x, e, ps, st)
32+
else
33+
y, st′ = l(g, x, ps, st)
34+
end
3135
@test eltype(y) == eltype(x)
3236
if outputsize !== nothing
3337
@test LuxCore.outputsize(l) == outputsize
@@ -42,4 +46,4 @@ function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
4246
test_gradients(loss, x, ps; atol, rtol, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
4347
end
4448

45-
end
49+
end

0 commit comments

Comments
 (0)