@@ -628,3 +628,74 @@ function Base.show(io::IO, l::GINConv)
628
628
print (io, " , $(l. ϵ) " )
629
629
print (io, " )" )
630
630
end
631
+
632
+ @concrete struct NNConv <: GNNContainerLayer{(:nn,)}
633
+ nn <: AbstractExplicitLayer
634
+ aggr
635
+ in_dims:: Int
636
+ out_dims:: Int
637
+ use_bias:: Bool
638
+ add_self_loops:: Bool
639
+ use_edge_weight:: Bool
640
+ init_weight
641
+ init_bias
642
+ σ
643
+ end
644
+
645
+ """
646
+ function NNConv(ch::Pair{Int, Int}, σ = identity;
647
+ init_weight = glorot_uniform,
648
+ init_bias = zeros32,
649
+ use_bias::Bool = true,
650
+ add_self_loops::Bool = true,
651
+ use_edge_weight::Bool = false,
652
+ allow_fast_activation::Bool = true)
653
+ """
654
+ # fix args order
655
+ function NNConv (ch:: Pair{Int, Int} , nn, σ = identity;
656
+ aggr = + ,
657
+ init_bias = zeros32,
658
+ use_bias:: Bool = true ,
659
+ init_weight = glorot_uniform,
660
+ add_self_loops:: Bool = true ,
661
+ use_edge_weight:: Bool = false ,
662
+ allow_fast_activation:: Bool = true )
663
+ in_dims, out_dims = ch
664
+ σ = allow_fast_activation ? NNlib. fast_act (σ) : σ
665
+ return NNConv (nn, aggr, in_dims, out_dims, use_bias, add_self_loops, use_edge_weight, init_weight, init_bias, σ)
666
+ end
667
+
668
+ function (l:: GCNConv )(g, x, edge_weight, ps, st)
669
+ nn = StatefulLuxLayer {true} (l. nn, ps, st)
670
+
671
+ # what would be the order of args here?
672
+ m = (; nn, l. aggr, ps. weight, bias = _getbias (ps),
673
+ l. add_self_loops, l. use_edge_weight, l. σ)
674
+ y = GNNlib. nn_conv (m, g, x, edge_weight)
675
+ stnew = _getstate (nn)
676
+ return y, stnew
677
+ end
678
+
679
+ function LuxCore. initialparameters (rng:: AbstractRNG , l:: NNConv )
680
+ weight = l. init_weight (rng, l. out_dims, l. in_dims)
681
+ if l. use_bias
682
+ bias = l. init_bias (rng, l. out_dims)
683
+ return (; weight, bias)
684
+ else
685
+ return (; weight)
686
+ end
687
+ end
688
+
689
+ LuxCore. parameterlength (l:: NNConv ) = l. use_bias ? l. in_dims * l. out_dims + l. out_dims : l. in_dims * l. out_dims # nn wont affect this right?
690
+ LuxCore. outputsize (d:: NNConv ) = (d. out_dims,)
691
+
692
+
693
+ function Base. show (io:: IO , l:: GINConv )
694
+ print (io, " NNConv($(l. nn) " )
695
+ print (io, " , $(l. ϵ) " )
696
+ l. σ == identity || print (io, " , " , l. σ)
697
+ l. use_bias || print (io, " , use_bias=false" )
698
+ l. add_self_loops || print (io, " , add_self_loops=false" )
699
+ ! l. use_edge_weight || print (io, " , use_edge_weight=true" )
700
+ print (io, " )" )
701
+ end
0 commit comments