@@ -1771,6 +1771,54 @@ function Base.show(io::IO, l::ResGatedGraphConv)
     print(io, ")")
 end
 
+@doc raw"""
+    SAGEConv(in => out, σ = identity; aggr = mean, init_weight = glorot_uniform, init_bias = zeros32, use_bias = true)
+
+GraphSAGE convolution layer from the paper [Inductive Representation Learning on Large Graphs](https://arxiv.org/pdf/1706.02216.pdf).
+
+Performs:
+```math
+\mathbf{x}_i' = W \cdot [\mathbf{x}_i; \square_{j \in \mathcal{N}(i)} \mathbf{x}_j]
+```
+
+where the aggregation type is selected by `aggr`.
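+
+For a single node `i` this concatenates the node's own features with the
+aggregated features of its neighbors and projects the result with `W`. A
+hand-computed sketch of that update (illustrative only, not the layer's
+implementation), assuming mean aggregation and no bias:
+
+```julia
+x = rand(Float32, 3, 4)        # 3 features, 4 nodes
+W = rand(Float32, 5, 2 * 3)    # out × 2in, applied to [x_i; aggregated neighbors]
+neighbors = [2, 3]             # N(1), the neighbors of node 1
+agg = sum(x[:, j] for j in neighbors) / length(neighbors)   # mean over N(1)
+x1 = W * vcat(x[:, 1], agg)    # updated features of node 1: 5-element vector
+```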
+
+# Arguments
+
+- `in`: The dimension of input features.
+- `out`: The dimension of output features.
+- `σ`: Activation function. Default `identity`.
+- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`). Default `mean`.
+- `init_weight`: Weights' initializer. Default `glorot_uniform`.
+- `init_bias`: Bias initializer. Default `zeros32`.
+- `use_bias`: Add learnable bias. Default `true`.
+
+# Examples
+
+```julia
+using GNNLux, Lux, Random
+
+# initialize random number generator
+rng = Random.default_rng()
+
+# create data
+s = [1,1,2,3]
+t = [2,3,1,1]
+in_channel = 3
+out_channel = 5
+g = GNNGraph(s, t)
+x = rand(rng, Float32, in_channel, g.num_nodes)
+
+# create layer
+l = SAGEConv(in_channel => out_channel, tanh, use_bias = false, aggr = +)
+
+# setup layer
+ps, st = LuxCore.setup(rng, l)
+
+# forward pass
+y, st = l(g, x, ps, st)    # size: out_channel × num_nodes
+```
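+
+The weight and bias initializers can be customized through `init_weight` and
+`init_bias`. A minimal sketch, assuming `glorot_uniform` and `zeros32` are in
+scope (they come from WeightInitializers.jl):
+
+```julia
+l = SAGEConv(in_channel => out_channel; aggr = mean,
+             init_weight = glorot_uniform,   # weight initializer
+             init_bias = zeros32)            # bias initializer
+ps, st = LuxCore.setup(rng, l)
+```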
+"""
 @concrete struct SAGEConv <: GNNLayer
     in_dims::Int
     out_dims::Int