@@ -844,3 +844,51 @@ function Base.show(io::IO, l::ResGatedGraphConv)
844844 l. use_bias || print (io, " , use_bias=false" )
845845 print (io, " )" )
846846end
847+
848+ @concrete struct SAGEConv <: GNNLayer
849+ in_dims:: Int
850+ out_dims:: Int
851+ use_bias:: Bool
852+ init_weight
853+ init_bias
854+ σ
855+ aggr
856+ end
857+
858+ function SAGEConv (ch:: Pair{Int, Int} , σ = identity;
859+ aggr = mean,
860+ init_weight = glorot_uniform,
861+ init_bias = zeros32,
862+ use_bias:: Bool = true )
863+ in_dims, out_dims = ch
864+ σ = NNlib. fast_act (σ)
865+ return SAGEConv (in_dims, out_dims, use_bias, init_weight, init_bias, σ, aggr)
866+ end
867+
868+ function LuxCore. initialparameters (rng:: AbstractRNG , l:: SAGEConv )
869+ weight = l. init_weight (rng, l. out_dims, 2 * l. in_dims)
870+ if l. use_bias
871+ bias = l. init_bias (rng, l. out_dims)
872+ return (; weight, bias)
873+ else
874+ return (; weight)
875+ end
876+ end
877+
878+ LuxCore. parameterlength (l:: SAGEConv ) = l. use_bias ? l. out_dims * 2 * l. in_dims + l. out_dims :
879+ l. out_dims * 2 * l. in_dims
880+ LuxCore. outputsize (d:: SAGEConv ) = (d. out_dims,)
881+
882+ function Base. show (io:: IO , l:: SAGEConv )
883+ print (io, " SAGEConv(" , l. in_dims, " => " , l. out_dims)
884+ (l. σ == identity) || print (io, " , " , l. σ)
885+ (l. aggr == mean) || print (io, " , aggr=" , l. aggr)
886+ l. use_bias || print (io, " , use_bias=false" )
887+ print (io, " )" )
888+ end
889+
890+ function (l:: SAGEConv )(g, x, ps, st)
891+ m = (; ps. weight, bias = _getbias (ps),
892+ l. σ, l. aggr)
893+ return GNNlib. sage_conv (m, g, x), st
894+ end
0 commit comments