Skip to content

Commit ffebc07

Browse files
committed
Add SAGEConv docs
1 parent b3d7462 commit ffebc07

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

GNNLux/src/layers/conv.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1771,6 +1771,54 @@ function Base.show(io::IO, l::ResGatedGraphConv)
17711771
print(io, ")")
17721772
end
17731773

1774+
@doc raw"""
1775+
SAGEConv(in => out, σ=identity; aggr=mean, init_weight = glorot_uniform, init_bias = zeros32, use_bias=true)
1776+
1777+
GraphSAGE convolution layer from paper [Inductive Representation Learning on Large Graphs](https://arxiv.org/pdf/1706.02216.pdf).
1778+
1779+
Performs:
1780+
```math
1781+
\mathbf{x}_i' = W \cdot [\mathbf{x}_i; \square_{j \in \mathcal{N}(i)} \mathbf{x}_j]
1782+
```
1783+
1784+
where the aggregation type is selected by `aggr`.
1785+
1786+
# Arguments
1787+
1788+
- `in`: The dimension of input features.
1789+
- `out`: The dimension of output features.
1790+
- `σ`: Activation function.
1791+
- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
1792+
- `init_bias`: Bias initializer. Default `zeros32`.
1793+
- `use_bias`: Add learnable bias. Default `true`.
1794+
1795+
1796+
# Examples:
1797+
1798+
```julia
1799+
using GNNLux, Lux, Random
1800+
1801+
# initialize random number generator
1802+
rng = Random.default_rng()
1803+
1804+
# create data
1805+
s = [1,1,2,3]
1806+
t = [2,3,1,1]
1807+
in_channel = 3
1808+
out_channel = 5
1809+
g = GNNGraph(s, t)
1810+
x = rand(rng, Float32, in_channel, g.num_nodes)
1811+
1812+
# create layer
1813+
l = SAGEConv(in_channel => out_channel, tanh, use_bias = false, aggr = +)
1814+
1815+
# setup layer
1816+
ps, st = LuxCore.setup(rng, l)
1817+
1818+
# forward pass
1819+
y, st = l(g, x, ps, st) # size: out_channel × num_nodes
1820+
```
1821+
"""
17741822
@concrete struct SAGEConv <: GNNLayer
17751823
in_dims::Int
17761824
out_dims::Int

0 commit comments

Comments
 (0)