Skip to content

Commit 67a51f7

Browse files
more layer
1 parent 4b4477e commit 67a51f7

File tree

2 files changed

+45
-2
lines changed

2 files changed

+45
-2
lines changed

GNNLux/src/GNNLux.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
module GNNLux
22
using ConcreteStructs: @concrete
33
using NNlib: NNlib, sigmoid, relu, swish
4-
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
4+
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer, parameterlength, statelength, outputsize,
5+
initialparameters, initialstates, parameterlength, statelength
56
using Lux: Lux, Chain, Dense, glorot_uniform, zeros32, StatefulLuxLayer
67
using Reexport: @reexport
78
using Random: AbstractRNG
@@ -22,7 +23,7 @@ export AGNNConv,
2223
DConv,
2324
GATConv,
2425
GATv2Conv,
25-
# GatedGraphConv,
26+
GatedGraphConv,
2627
GCNConv,
2728
# GINConv,
2829
# GMMConv,

GNNLux/src/layers/conv.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,4 +515,46 @@ function Base.show(io::IO, l::GATv2Conv)
515515
l.σ == identity || print(io, ", ", l.σ)
516516
print(io, ", negative_slope=", l.negative_slope)
517517
print(io, ")")
518+
end
519+
520+
521+
@concrete struct GatedGraphConv <: GRULayer
522+
gru
523+
init_weight
524+
dims::Int
525+
num_layers::Int
526+
aggr
527+
end
528+
529+
530+
function GatedGraphConv(dims::Int, num_layers::Int;
531+
aggr = +, init_weight = glorot_uniform)
532+
gru = GRUCell(dims => dims)
533+
return GatedGraphConv(gru, init_weight, dims, num_layers, aggr)
534+
end
535+
536+
LucCore.outputsize(l::GatedGraphConv) = (l.dims,)
537+
538+
function LuxCore.initialparameters(rng::AbstractRNG, l::GatedGraphConv)
539+
gru = LuxCore.initialparameters(rng, l.gru)
540+
weight = l.init_weight(rng, l.dims, l.dims)
541+
return (; gru, weight)
542+
end
543+
544+
LuxCore.parameterlength(l::GatedGraphConv) = parameterlength(l.gru) + l.dims^2
545+
546+
function LuxCore.initialstates(rng::AbstractRNG, l::GatedGraphConv)
547+
return (; gru = LuxCore.initialstates(rng, l.gru))
548+
end
549+
550+
LuxCore.statelength(l::GatedGraphConv) = statelength(l.gru)
551+
552+
function (l::GatedGraphConv)(g, H, ps, st)
553+
GNNlib.gated_graph_conv(l, g, H)
554+
end
555+
556+
function Base.show(io::IO, l::GatedGraphConv)
557+
print(io, "GatedGraphConv($(l.dims), $(l.num_layers)")
558+
print(io, ", aggr=", l.aggr)
559+
print(io, ")")
518560
end

0 commit comments

Comments
 (0)