@@ -38,7 +38,6 @@ function LuxCore.initialparameters(rng::AbstractRNG, l::GCNConv)
38
38
end
39
39
40
40
LuxCore. parameterlength (l:: GCNConv ) = l. use_bias ? l. in_dims * l. out_dims + l. out_dims : l. in_dims * l. out_dims
41
- LuxCore. statelength (d:: GCNConv ) = 0
42
41
LuxCore. outputsize (d:: GCNConv ) = (d. out_dims,)
43
42
44
43
function Base. show (io:: IO , l:: GCNConv )
@@ -518,7 +517,7 @@ function Base.show(io::IO, l::GATv2Conv)
518
517
end
519
518
520
519
521
- @concrete struct GatedGraphConv <: GRULayer
520
+ @concrete struct GatedGraphConv <: GNNLayer
522
521
gru
523
522
init_weight
524
523
dims:: Int
@@ -533,28 +532,48 @@ function GatedGraphConv(dims::Int, num_layers::Int;
533
532
return GatedGraphConv (gru, init_weight, dims, num_layers, aggr)
534
533
end
535
534
536
- LucCore . outputsize (l:: GatedGraphConv ) = (l. dims,)
535
+ LuxCore . outputsize (l:: GatedGraphConv ) = (l. dims,)
537
536
538
537
function LuxCore. initialparameters (rng:: AbstractRNG , l:: GatedGraphConv )
539
538
gru = LuxCore. initialparameters (rng, l. gru)
540
- weight = l. init_weight (rng, l. dims, l. dims)
539
+ weight = l. init_weight (rng, l. dims, l. dims, l . num_layers )
541
540
return (; gru, weight)
542
541
end
543
542
544
- LuxCore. parameterlength (l:: GatedGraphConv ) = parameterlength (l. gru) + l. dims^ 2
543
+ LuxCore. parameterlength (l:: GatedGraphConv ) = parameterlength (l. gru) + l. dims^ 2 * l . num_layers
545
544
546
- function LuxCore. initialstates (rng:: AbstractRNG , l:: GatedGraphConv )
547
- return (; gru = LuxCore. initialstates (rng, l. gru))
548
- end
549
-
550
- LuxCore. statelength (l:: GatedGraphConv ) = statelength (l. gru)
551
545
552
- function (l:: GatedGraphConv )(g, H, ps, st)
553
- GNNlib. gated_graph_conv (l, g, H)
546
+ function (l:: GatedGraphConv )(g, x, ps, st)
547
+ gru = StatefulLuxLayer {true} (l. gru, ps. gru, _getstate (st, :gru ))
548
+ fgru = (h, x) -> gru ((x, (h,))) # make the forward compatible with Flux.GRUCell style
549
+ m = (; gru= fgru, ps. weight, l. num_layers, l. aggr, l. dims)
550
+ return GNNlib. gated_graph_conv (m, g, x), st
554
551
end
555
552
556
553
function Base. show (io:: IO , l:: GatedGraphConv )
557
554
print (io, " GatedGraphConv($(l. dims) , $(l. num_layers) " )
558
555
print (io, " , aggr=" , l. aggr)
559
556
print (io, " )" )
560
- end
557
+ end
558
+
559
+ @concrete struct GINConv <: GNNContainerLayer{(:nn,)}
560
+ nn <: AbstractExplicitLayer
561
+ ϵ <: Real
562
+ aggr
563
+ end
564
+
565
+ GINConv (nn, ϵ; aggr = + ) = GINConv (nn, ϵ, aggr)
566
+
567
+ function (l:: GINConv )(g, x, ps, st)
568
+ nn = StatefulLuxLayer {true} (l. nn, ps, st)
569
+ m = (; nn, l. ϵ, l. aggr)
570
+ y = GNNlib. gin_conv (m, g, x)
571
+ stnew = _getstate (nn)
572
+ return y, stnew
573
+ end
574
+
575
+ function Base. show (io:: IO , l:: GINConv )
576
+ print (io, " GINConv($(l. nn) " )
577
+ print (io, " , $(l. ϵ) " )
578
+ print (io, " )" )
579
+ end
0 commit comments