@@ -1181,3 +1181,120 @@ function Base.show(io::IO, l::GMMConv)
    l.residual == true || print(io, ", residual=", l.residual)
    print(io, ")")
end
+
+@doc raw"""
+    SGConv(in => out, k=1; [bias, init, add_self_loops, use_edge_weight])
+
+SGC layer from [Simplifying Graph Convolutional Networks](https://arxiv.org/pdf/1902.07153.pdf).
+Performs the operation
+```math
+H^{K} = (\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2})^K X \Theta
+```
+where ``\tilde{A}`` is ``A + I``.
+
+# Arguments
+
+- `in`: Number of input features.
+- `out`: Number of output features.
+- `k`: Number of hops k. Default `1`.
+- `bias`: Add learnable bias. Default `true`.
+- `init`: Weights' initializer. Default `glorot_uniform`.
+- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `true`.
+- `use_edge_weight`: If `true`, consider the edge weights in the input graph (if available).
+                     If `add_self_loops=true` the new weights will be set to 1. Default `false`.
+
+# Examples
+
+```julia
+# create data
+s = [1, 1, 2, 3]
+t = [2, 3, 1, 1]
+g = GNNGraph(s, t)
+x = randn(3, g.num_nodes)
+
+# create layer
+l = SGConv(3 => 5; add_self_loops = true)
+
+# forward pass
+y = l(g, x)   # size: 5 × num_nodes
+
+# convolution with edge weights
+w = [1.1, 0.1, 2.3, 0.5]
+y = l(g, x, w)
+
+# Edge weights can also be embedded in the graph.
+g = GNNGraph(s, t, w)
+l = SGConv(3 => 5, add_self_loops = true, use_edge_weight = true)
+y = l(g, x)   # same as l(g, x, w)
+```
+"""
+struct SGConv{A <: AbstractMatrix, B} <: GNNLayer
+    weight::A
+    bias::B
+    k::Int
+    add_self_loops::Bool
+    use_edge_weight::Bool
+end
+
+@functor SGConv
+
+function SGConv(ch::Pair{Int, Int}, k = 1;
+                init = glorot_uniform,
+                bias::Bool = true,
+                add_self_loops = true,
+                use_edge_weight = false)
+    in, out = ch
+    W = init(out, in)
+    b = bias ? Flux.create_bias(W, true, out) : false
+    SGConv(W, b, k, add_self_loops, use_edge_weight)
+end
+
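+# Forward pass on a COO graph: `x` is a (num_input_features × num_nodes) matrix and the
+# optional `edge_weight` vector takes precedence over any edge weights stored in `g`.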
+function (l::SGConv)(g::GNNGraph, x::AbstractMatrix{T}, edge_weight::EW = nothing) where
+    {T, EW <: Union{Nothing, AbstractVector}}
+    @assert !(g isa GNNGraph{<:ADJMAT_T} && edge_weight !== nothing) "Providing external edge_weight is not yet supported for adjacency matrix graphs"
+
+    if edge_weight !== nothing
+        @assert length(edge_weight) == g.num_edges "Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))"
+    end
+
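+    # Optionally add self-loops; weights for the new self-loop edges are set to 1,
+    # as documented above.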
+    if l.add_self_loops
+        g = add_self_loops(g)
+        if edge_weight !== nothing
+            edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)]
+            @assert length(edge_weight) == g.num_edges
+        end
+    end
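+    # Apply the linear projection before propagation when it reduces the feature
+    # dimension (Dout < Din), so the k hops operate on fewer features; otherwise
+    # the projection is applied after the loop.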
+    Dout, Din = size(l.weight)
+    if Dout < Din
+        x = l.weight * x
+    end
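+    # Symmetric normalization: c = 1/sqrt(in-degree) is applied on both sides of each
+    # of the k aggregation steps, implementing the (D^{-1/2} A D^{-1/2})^K X operation
+    # from the docstring (with self-loops included when added above). Each step picks
+    # the message rule based on the available weights: explicit `edge_weight`, weights
+    # stored in the graph, or an unweighted copy.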
+    d = degree(g, T; dir = :in, edge_weight)
+    c = 1 ./ sqrt.(d)
+    for iter in 1:l.k
+        x = x .* c'
+        if edge_weight !== nothing
+            x = propagate(e_mul_xj, g, +, xj = x, e = edge_weight)
+        elseif l.use_edge_weight
+            x = propagate(w_mul_xj, g, +, xj = x)
+        else
+            x = propagate(copy_xj, g, +, xj = x)
+        end
+        x = x .* c'
+    end
+    if Dout >= Din
+        x = l.weight * x
+    end
+    return (x .+ l.bias)
+end
+
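+# External edge weights are not supported for adjacency-matrix graphs, so convert the
+# graph to COO form and dispatch to the method above.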
+function (l::SGConv)(g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix, edge_weight::AbstractVector)
+    g = GNNGraph(edge_index(g)...; g.num_nodes)
+    return l(g, x, edge_weight)
+end
+
+function Base.show(io::IO, l::SGConv)
+    out, in = size(l.weight)
+    print(io, "SGConv($in => $out")
+    l.k == 1 || print(io, ", ", l.k)
+    print(io, ")")
+end