Skip to content

Commit 866d20c

Browse files
committed
Add SGConv
1 parent 0681b54 commit 866d20c

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed

GNNLux/src/layers/conv.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,6 +1077,63 @@ function Base.show(io::IO, l::GATv2Conv)
10771077
print(io, ")")
10781078
end
10791079

1080+
@doc raw"""
1081+
SGConv(int => out, k = 1; init_weight = glorot_uniform, init_bias = zeros32, use_bias = true, add_self_loops = true,use_edge_weight = false)
1082+
1083+
SGC layer from [Simplifying Graph Convolutional Networks](https://arxiv.org/pdf/1902.07153.pdf)
1084+
Performs operation
1085+
```math
1086+
H^{K} = (\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2})^K X \Theta
1087+
```
1088+
where ``\tilde{A}`` is ``A + I``.
1089+
1090+
# Arguments
1091+
1092+
- `in`: Number of input features.
1093+
- `out`: Number of output features.
1094+
- `k` : Number of hops k. Default `1`.
1095+
- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `false`.
1096+
- `use_edge_weight`: If `true`, consider the edge weights in the input graph (if available).
1097+
If `add_self_loops=true` the new weights will be set to 1. Default `false`.
1098+
- `init_weight`: Weights' initializer. Default `glorot_uniform`.
1099+
- `init_bias`: Bias initializer. Default `zeros32`.
1100+
- `use_bias`: Add learnable bias. Default `true`.
1101+
1102+
1103+
# Examples
1104+
1105+
```julia
1106+
using GNNLux, Lux, Random
1107+
1108+
# initialize random number generator
1109+
rng = Random.default_rng()
1110+
1111+
# create data
1112+
s = [1,1,2,3]
1113+
t = [2,3,1,1]
1114+
g = GNNGraph(s, t)
1115+
x = randn(rng, Float32, 3, g.num_nodes)
1116+
1117+
# create layer
1118+
l = SGConv(3 => 5; add_self_loops = true)
1119+
1120+
# setup layer
1121+
ps, st = LuxCore.setup(rng, l)
1122+
1123+
# forward pass
1124+
y, st = l(g, x, ps, st) # size: 5 × num_nodes
1125+
1126+
# convolution with edge weights
1127+
w = [1.1, 0.1, 2.3, 0.5]
1128+
y = l(g, x, w, ps, st)
1129+
1130+
# Edge weights can also be embedded in the graph.
1131+
g = GNNGraph(s, t, w)
1132+
l = SGConv(3 => 5, add_self_loops = true, use_edge_weight=true)
1133+
ps, st = LuxCore.setup(rng, l)
1134+
y, st = l(g, x, ps, st) # same as l(g, x, w)
1135+
```
1136+
"""
10801137
@concrete struct SGConv <: GNNLayer
10811138
in_dims::Int
10821139
out_dims::Int

0 commit comments

Comments
 (0)