Skip to content

Commit c7eb715

Browse files
more layers
1 parent 67a51f7 commit c7eb715

File tree

4 files changed

+48
-16
lines changed

4 files changed

+48
-16
lines changed

GNNLux/src/GNNLux.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ using ConcreteStructs: @concrete
33
using NNlib: NNlib, sigmoid, relu, swish
44
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer, parameterlength, statelength, outputsize,
55
initialparameters, initialstates, parameterlength, statelength
6-
using Lux: Lux, Chain, Dense, glorot_uniform, zeros32, StatefulLuxLayer
6+
using Lux: Lux, Chain, Dense, GRUCell,
7+
glorot_uniform, zeros32,
8+
StatefulLuxLayer
79
using Reexport: @reexport
810
using Random: AbstractRNG
911
using GNNlib: GNNlib
@@ -25,7 +27,7 @@ export AGNNConv,
2527
GATv2Conv,
2628
GatedGraphConv,
2729
GCNConv,
28-
# GINConv,
30+
GINConv,
2931
# GMMConv,
3032
GraphConv
3133
# MEGNetConv,

GNNLux/src/layers/conv.jl

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ function LuxCore.initialparameters(rng::AbstractRNG, l::GCNConv)
3838
end
3939

4040
LuxCore.parameterlength(l::GCNConv) = l.use_bias ? l.in_dims * l.out_dims + l.out_dims : l.in_dims * l.out_dims
41-
LuxCore.statelength(d::GCNConv) = 0
4241
LuxCore.outputsize(d::GCNConv) = (d.out_dims,)
4342

4443
function Base.show(io::IO, l::GCNConv)
@@ -518,7 +517,7 @@ function Base.show(io::IO, l::GATv2Conv)
518517
end
519518

520519

521-
@concrete struct GatedGraphConv <: GRULayer
520+
@concrete struct GatedGraphConv <: GNNLayer
522521
gru
523522
init_weight
524523
dims::Int
@@ -533,28 +532,48 @@ function GatedGraphConv(dims::Int, num_layers::Int;
533532
return GatedGraphConv(gru, init_weight, dims, num_layers, aggr)
534533
end
535534

536-
LucCore.outputsize(l::GatedGraphConv) = (l.dims,)
535+
LuxCore.outputsize(l::GatedGraphConv) = (l.dims,)
537536

538537
function LuxCore.initialparameters(rng::AbstractRNG, l::GatedGraphConv)
539538
gru = LuxCore.initialparameters(rng, l.gru)
540-
weight = l.init_weight(rng, l.dims, l.dims)
539+
weight = l.init_weight(rng, l.dims, l.dims, l.num_layers)
541540
return (; gru, weight)
542541
end
543542

544-
LuxCore.parameterlength(l::GatedGraphConv) = parameterlength(l.gru) + l.dims^2
543+
LuxCore.parameterlength(l::GatedGraphConv) = parameterlength(l.gru) + l.dims^2*l.num_layers
545544

546-
function LuxCore.initialstates(rng::AbstractRNG, l::GatedGraphConv)
547-
return (; gru = LuxCore.initialstates(rng, l.gru))
548-
end
549-
550-
LuxCore.statelength(l::GatedGraphConv) = statelength(l.gru)
551545

552-
function (l::GatedGraphConv)(g, H, ps, st)
553-
GNNlib.gated_graph_conv(l, g, H)
546+
function (l::GatedGraphConv)(g, x, ps, st)
547+
gru = StatefulLuxLayer{true}(l.gru, ps.gru, _getstate(st, :gru))
548+
fgru = (h, x) -> gru((x, (h,))) # make the forward compatible with Flux.GRUCell style
549+
m = (; gru=fgru, ps.weight, l.num_layers, l.aggr, l.dims)
550+
return GNNlib.gated_graph_conv(m, g, x), st
554551
end
555552

556553
function Base.show(io::IO, l::GatedGraphConv)
557554
print(io, "GatedGraphConv($(l.dims), $(l.num_layers)")
558555
print(io, ", aggr=", l.aggr)
559556
print(io, ")")
560-
end
557+
end
558+
559+
@concrete struct GINConv <: GNNContainerLayer{(:nn,)}
560+
nn <: AbstractExplicitLayer
561+
ϵ <: Real
562+
aggr
563+
end
564+
565+
GINConv(nn, ϵ; aggr = +) = GINConv(nn, ϵ, aggr)
566+
567+
function (l::GINConv)(g, x, ps, st)
568+
nn = StatefulLuxLayer{true}(l.nn, ps, st)
569+
m = (; nn, l.ϵ, l.aggr)
570+
y = GNNlib.gin_conv(m, g, x)
571+
stnew = _getstate(nn)
572+
return y, stnew
573+
end
574+
575+
function Base.show(io::IO, l::GINConv)
576+
print(io, "GINConv($(l.nn)")
577+
print(io, ", $(l.ϵ)")
578+
print(io, ")")
579+
end

GNNLux/test/layers/conv_tests.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,5 +77,15 @@
7777

7878
#TODO test edge
7979
end
80-
end
8180

81+
@testset "GatedGraphConv" begin
82+
l = GatedGraphConv(in_dims, 3)
83+
test_lux_layer(rng, l, g, x, outputsize=(in_dims,))
84+
end
85+
86+
@testset "GINConv" begin
87+
nn = Chain(Dense(in_dims => out_dims, relu), Dense(out_dims => out_dims))
88+
l = GINConv(nn, 0.5)
89+
test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true)
90+
end
91+
end

GNNLux/test/shared_testsetup.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
2828
@test LuxCore.statelength(l) == LuxCore.statelength(st)
2929

3030
y, st′ = l(g, x, ps, st)
31+
@test eltype(y) == eltype(x)
3132
if outputsize !== nothing
3233
@test LuxCore.outputsize(l) == outputsize
3334
end

0 commit comments

Comments
 (0)