@@ -515,4 +515,46 @@ function Base.show(io::IO, l::GATv2Conv)
515
515
l. σ == identity || print (io, " , " , l. σ)
516
516
print (io, " , negative_slope=" , l. negative_slope)
517
517
print (io, " )" )
518
+ end
519
+
520
+
521
+ @concrete struct GatedGraphConv <: GRULayer
522
+ gru
523
+ init_weight
524
+ dims:: Int
525
+ num_layers:: Int
526
+ aggr
527
+ end
528
+
529
+
530
+ function GatedGraphConv (dims:: Int , num_layers:: Int ;
531
+ aggr = + , init_weight = glorot_uniform)
532
+ gru = GRUCell (dims => dims)
533
+ return GatedGraphConv (gru, init_weight, dims, num_layers, aggr)
534
+ end
535
+
536
+ LucCore. outputsize (l:: GatedGraphConv ) = (l. dims,)
537
+
538
+ function LuxCore. initialparameters (rng:: AbstractRNG , l:: GatedGraphConv )
539
+ gru = LuxCore. initialparameters (rng, l. gru)
540
+ weight = l. init_weight (rng, l. dims, l. dims)
541
+ return (; gru, weight)
542
+ end
543
+
544
+ LuxCore. parameterlength (l:: GatedGraphConv ) = parameterlength (l. gru) + l. dims^ 2
545
+
546
+ function LuxCore. initialstates (rng:: AbstractRNG , l:: GatedGraphConv )
547
+ return (; gru = LuxCore. initialstates (rng, l. gru))
548
+ end
549
+
550
+ LuxCore. statelength (l:: GatedGraphConv ) = statelength (l. gru)
551
+
552
+ function (l:: GatedGraphConv )(g, H, ps, st)
553
+ GNNlib. gated_graph_conv (l, g, H)
554
+ end
555
+
556
+ function Base. show (io:: IO , l:: GatedGraphConv )
557
+ print (io, " GatedGraphConv($(l. dims) , $(l. num_layers) " )
558
+ print (io, " , aggr=" , l. aggr)
559
+ print (io, " )" )
518
560
end
0 commit comments