Commit eb43bc9

SGConv (#154)
* Initial skeleton
* Added to main file
* Small change in .gitignore
* Indent fix
* Fixed line endings
* Fixed trailing whitespace
* Trailing whitespaces
* Main computation
* Added example
* Added tests
* Small fix
1 parent c7d0afe commit eb43bc9

File tree

4 files changed (+136 −0 lines):

- .gitignore
- src/GraphNeuralNetworks.jl
- src/layers/conv.jl
- test/layers/conv.jl


.gitignore

Lines changed: 3 additions & 0 deletions
```diff
@@ -1,6 +1,9 @@
 *.jl.*.cov
 *.jl.cov
 *.jl.mem
+*~
+*.swp
+*.swo
 Manifest.toml
 /docs/build/
 .vscode
```

src/GraphNeuralNetworks.jl

Lines changed: 1 addition & 0 deletions
```diff
@@ -61,6 +61,7 @@ export
   ResGatedGraphConv,
   SAGEConv,
   GMMConv,
+  SGConv,
 
   # layers/pool
   GlobalPool,
```
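With `SGConv` added to the export list, the layer becomes part of the package's public API. A quick sanity check, not part of the commit, assuming the package builds with this change:

```julia
using GraphNeuralNetworks

# the new symbol is exported alongside the other conv layers
@assert isdefined(GraphNeuralNetworks, :SGConv)

l = SGConv(4 => 8, 2)   # 4 input channels, 8 output channels, k = 2 hops
```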

src/layers/conv.jl

Lines changed: 117 additions & 0 deletions
````diff
@@ -1181,3 +1181,120 @@ function Base.show(io::IO, l::GMMConv)
     l.residual==true || print(io, ", residual=", l.residual)
     print(io, ")")
 end
+
+@doc raw"""
+    SGConv(in => out, k=1; [bias, init, add_self_loops, use_edge_weight])
+
+SGC layer from [Simplifying Graph Convolutional Networks](https://arxiv.org/pdf/1902.07153.pdf),
+which performs the operation
+```math
+H^{K} = (\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2})^K X \Theta
+```
+where ``\tilde{A} = A + I`` is the adjacency matrix with self loops added and
+``\tilde{D}`` is its diagonal degree matrix.
+
+# Arguments
+
+- `in`: Number of input features.
+- `out`: Number of output features.
+- `k`: Number of hops. Default `1`.
+- `bias`: Add a learnable bias. Default `true`.
+- `init`: Weights' initializer. Default `glorot_uniform`.
+- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `true`.
+- `use_edge_weight`: If `true`, consider the edge weights in the input graph (if available).
+  If `add_self_loops=true`, the weights of the new self loops are set to 1. Default `false`.
+
+# Examples
+
+```julia
+# create data
+s = [1,1,2,3]
+t = [2,3,1,1]
+g = GNNGraph(s, t)
+x = randn(3, g.num_nodes)
+
+# create layer
+l = SGConv(3 => 5; add_self_loops = true)
+
+# forward pass
+y = l(g, x)   # size: 5 × num_nodes
+
+# convolution with edge weights
+w = [1.1, 0.1, 2.3, 0.5]
+y = l(g, x, w)
+
+# Edge weights can also be embedded in the graph.
+g = GNNGraph(s, t, w)
+l = SGConv(3 => 5, add_self_loops = true, use_edge_weight = true)
+y = l(g, x)   # same as l(g, x, w)
+```
+"""
+struct SGConv{A<:AbstractMatrix, B} <: GNNLayer
+    weight::A
+    bias::B
+    k::Int
+    add_self_loops::Bool
+    use_edge_weight::Bool
+end
+
+@functor SGConv
+
+function SGConv(ch::Pair{Int,Int}, k=1;
+                init=glorot_uniform,
+                bias::Bool=true,
+                add_self_loops=true,
+                use_edge_weight=false)
+    in, out = ch
+    W = init(out, in)
+    b = bias ? Flux.create_bias(W, true, out) : false
+    SGConv(W, b, k, add_self_loops, use_edge_weight)
+end
+
+function (l::SGConv)(g::GNNGraph, x::AbstractMatrix{T}, edge_weight::EW=nothing) where
+    {T, EW<:Union{Nothing,AbstractVector}}
+    @assert !(g isa GNNGraph{<:ADJMAT_T} && edge_weight !== nothing) "Providing external edge_weight is not yet supported for adjacency matrix graphs"
+
+    if edge_weight !== nothing
+        @assert length(edge_weight) == g.num_edges "Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))"
+    end
+
+    if l.add_self_loops
+        g = add_self_loops(g)
+        if edge_weight !== nothing
+            # the new self loops get weight 1
+            edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)]
+            @assert length(edge_weight) == g.num_edges
+        end
+    end
+    Dout, Din = size(l.weight)
+    if Dout < Din
+        # apply Θ before propagation when it shrinks the feature dimension
+        x = l.weight * x
+    end
+    d = degree(g, T; dir=:in, edge_weight)
+    c = 1 ./ sqrt.(d)
+    for iter in 1:l.k
+        x = x .* c'
+        if edge_weight !== nothing
+            x = propagate(e_mul_xj, g, +, xj=x, e=edge_weight)
+        elseif l.use_edge_weight
+            x = propagate(w_mul_xj, g, +, xj=x)
+        else
+            x = propagate(copy_xj, g, +, xj=x)
+        end
+        x = x .* c'
+    end
+    if Dout >= Din
+        x = l.weight * x
+    end
+    return (x .+ l.bias)
+end
+
+function (l::SGConv)(g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix, edge_weight::AbstractVector)
+    g = GNNGraph(edge_index(g)...; g.num_nodes)
+    return l(g, x, edge_weight)
+end
+
+function Base.show(io::IO, l::SGConv)
+    out, in = size(l.weight)
+    print(io, "SGConv($in => $out")
+    l.k == 1 || print(io, ", ", l.k)
+    print(io, ")")
+end
````
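The loop above implements the ``(\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2})^K X \Theta`` product edge-by-edge via `propagate`, scaling features by the inverse square-root degrees before and after each hop. As a sanity check, the same result can be reproduced with dense linear algebra. This is a sketch, not part of the commit; it assumes an undirected graph, the default `add_self_loops = true`, and features stored as a `channels × nodes` matrix, so the normalized adjacency multiplies from the right:

```julia
using GraphNeuralNetworks, LinearAlgebra

s, t = [1, 1, 2, 3], [2, 3, 1, 1]
g = GNNGraph(s, t)
x = randn(3, g.num_nodes)      # 3 input channels
k = 2
l = SGConv(3 => 5, k)          # add_self_loops defaults to true

A = Matrix(adjacency_matrix(g)) + I    # Ã = A + I
d = vec(sum(A; dims = 1))              # degrees of Ã
C = Diagonal(1 ./ sqrt.(d))            # inverse sqrt degree matrix of Ã
S = C * A * C                          # normalized adjacency

# features are (channels × nodes), so S^k acts on the right
# and Θ (l.weight) on the left
H = l.weight * (x * S^k) .+ l.bias

isapprox(H, l(g, x); rtol = 1e-5)      # true up to floating-point error
```

Because `S` is applied `k` times before a single weight multiplication, there are no nonlinearities between hops, which is exactly the simplification the SGC paper proposes.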

test/layers/conv.jl

Lines changed: 15 additions & 0 deletions
```diff
@@ -272,4 +272,19 @@
             test_layer(l, g, rtol=RTOL_HIGH, outsize = (out_channel, g.num_nodes))
         end
     end
+
+    @testset "SGConv" begin
+        K = [1, 2, 3] # for different number of hops
+        for k in K
+            l = SGConv(in_channel => out_channel, k, add_self_loops = true)
+            for g in test_graphs
+                test_layer(l, g, rtol=RTOL_HIGH, outsize=(out_channel, g.num_nodes))
+            end
+
+            l = SGConv(in_channel => out_channel, k, add_self_loops = true)
+            for g in test_graphs
+                test_layer(l, g, rtol=RTOL_HIGH, outsize=(out_channel, g.num_nodes))
+            end
+        end
+    end
 end
```
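Alongside the repo's `test_layer` helper, a quick way to confirm the layer is trainable is to check that gradients reach its parameters. A minimal sketch, not part of the commit, using the implicit-parameter API Flux provided at the time:

```julia
using Flux, GraphNeuralNetworks

g = GNNGraph([1, 1, 2, 3], [2, 3, 1, 1])
x = randn(Float32, 3, g.num_nodes)
l = SGConv(3 => 5, 2)

# differentiate a scalar loss with respect to the layer's parameters
ps = Flux.params(l)
gs = Flux.gradient(() -> sum(abs2, l(g, x)), ps)

@assert gs[l.weight] !== nothing   # gradient reached the weight matrix
@assert gs[l.bias] !== nothing     # and the bias
```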
