@doc raw """
-     GraphConv(in => out, σ=identity, aggr=+; bias=true, init=glorot_uniform)
+     GraphConv(in => out, σ=identity; aggr=+, bias=true, init=glorot_uniform)
Graph convolution layer from the paper [Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks](https://arxiv.org/abs/1810.02244).
@functor GraphConv
- function GraphConv(ch::Pair{Int,Int}, σ=identity, aggr=+;
+ function GraphConv(ch::Pair{Int,Int}, σ=identity; aggr=+,
                     init=glorot_uniform, bias::Bool=true)
    in, out = ch
    W1 = init(out, in)
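For reference, the new keyword form would be used like this (a sketch with made-up sizes, not part of the diff):

```julia
using GraphNeuralNetworks, Flux

# Before: GraphConv(16 => 32, relu, +)        # `aggr` was positional
# After:  `aggr` moves behind the semicolon, i.e. becomes a keyword:
l = GraphConv(16 => 32, relu; aggr=+, bias=true)   # hypothetical 16 => 32 sizes
```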
@@ -214,9 +214,9 @@ Implements the operation
```math
\mathbf{x}_i' = \sum_{j \in N(i)} \alpha_{ij} W \mathbf{x}_j
```
- where the attention coefficient ``\alpha_{ij}`` is given by
+ where the attention coefficients ``\alpha_{ij}`` are given by
```math
- \alpha_{ij} = \frac{1}{z_i} \exp(LeakyReLU(\mathbf{a}^T [W \mathbf{x}_i || W \mathbf{x}_j]))
+ \alpha_{ij} = \frac{1}{z_i} \exp(LeakyReLU(\mathbf{a}^T [W \mathbf{x}_i \,\|\, W \mathbf{x}_j]))
```
with ``z_i`` a normalization factor.
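As a concrete reading of the formula, here is a toy computation of the coefficients for a single node ``i`` (assumed sizes and a hand-rolled loop, not the layer's vectorized code):

```julia
using Flux: leakyrelu

out = 4
W   = randn(Float32, out, out)                 # shared weight
a   = randn(Float32, 2out)                     # attention vector
xi  = randn(Float32, out)                      # features of node i
xjs = [randn(Float32, out) for _ in 1:3]       # features of N(i)

logits = [leakyrelu(a' * vcat(W * xi, W * xj), 0.2f0) for xj in xjs]
α = exp.(logits) ./ sum(exp.(logits))          # z_i = Σ_j exp(...) so Σ_j α_ij = 1
```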
@@ -225,9 +225,9 @@ with ``z_i`` a normalization factor.
- `in`: The dimension of input features.
- `out`: The dimension of output features.
- `bias::Bool`: Keyword argument, whether to learn the additive bias.
- - `heads`: Number attention heads
+ - `heads`: Number of attention heads.
- `concat`: Concatenate layer output or not. If not, layer output is averaged over the heads.
- - `negative_slope::Real`: Keyword argument, the parameter of LeakyReLU.
+ - `negative_slope`: The parameter of LeakyReLU.
"""
struct GATConv{T, A<:AbstractMatrix, B} <: GNNLayer
    weight::A
@@ -248,14 +248,18 @@ function GATConv(ch::Pair{Int,Int}, σ=identity;
                 init=glorot_uniform, bias::Bool=true)
    in, out = ch
    W = init(out*heads, in)
-     b = bias ? Flux.create_bias(W, true, out*heads) : false
+     if concat
+         b = bias ? Flux.create_bias(W, true, out*heads) : false
+     else
+         b = bias ? Flux.create_bias(W, true, out) : false
+     end
    a = init(2*out, heads)
    negative_slope = convert(eltype(W), negative_slope)
    GATConv(W, b, a, σ, negative_slope, ch, heads, concat)
end
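The reason for the new branch: with `concat=false` the heads are averaged, so the layer emits `out` features per node and the bias must have length `out` rather than `out*heads`. A quick check (hypothetical sizes):

```julia
using GraphNeuralNetworks

l1 = GATConv(8 => 4; heads=2, concat=true)
l2 = GATConv(8 => 4; heads=2, concat=false)
length(l1.bias)   # 8 = out*heads
length(l2.bias)   # 4 = out
```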
function compute_message(l::GATConv, Wxi, Wxj)
-     aWW = sum(l.a .* cat(Wxi, Wxj, dims=1), dims=1)   # 1 × nheads × nedges
+     aWW = sum(l.a .* vcat(Wxi, Wxj), dims=1)          # 1 × nheads × nedges
    α = exp.(leakyrelu.(aWW, l.negative_slope))
    return (α = α, m = α .* Wxj)
end
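The `cat(..., dims=1)` → `vcat` change is purely cosmetic: for arrays they produce the same result. A shape sketch under assumed toy sizes:

```julia
out, heads, nedges = 4, 2, 3
Wxi = randn(Float32, out, heads, nedges)
Wxj = randn(Float32, out, heads, nedges)
a   = randn(Float32, 2out, heads)              # stands in for l.a

# vcat stacks along dim 1, exactly like cat(Wxi, Wxj, dims=1):
aWW = sum(a .* vcat(Wxi, Wxj), dims=1)         # size (1, heads, nedges)
```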
@@ -273,14 +277,13 @@ function (l::GATConv)(g::GNNGraph, x::AbstractMatrix)
    x, _ = propagate(l, g, +, Wx)   ## chout × nheads × nnodes
-     b = reshape(l.bias, chout, heads)
-     x = l.σ.(x .+ b)
    if !l.concat
-         x = sum(x, dims=2)
+         x = mean(x, dims=2)
    end
+     x = reshape(x, :, size(x, 3))   # return a matrix
+     x = l.σ.(x .+ l.bias)

-     # We finally return a matrix
-     return reshape(x, :, size(x, 3))
+     return x
end
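End to end, the reordering means head averaging now happens before bias and activation, and the output is always a matrix. A smoke test (assumes `rand_graph`; sizes are made up):

```julia
using GraphNeuralNetworks

g = rand_graph(5, 10)                          # 5 nodes, 10 edges (hypothetical)
x = rand(Float32, 8, 5)
l = GATConv(8 => 4; heads=2, concat=false)
size(l(g, x))                                  # (4, 5): averaged heads, then bias + σ
```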
@@ -514,3 +517,60 @@ function Base.show(io::IO, l::NNConv)
    print(io, ", aggr=", l.aggr)
    print(io, ")")
end
+
+
+ @doc raw """
+     SAGEConv(in => out, σ=identity; aggr=mean, bias=true, init=glorot_uniform)
+
+ GraphSAGE convolution layer from the paper [Inductive Representation Learning on Large Graphs](https://arxiv.org/pdf/1706.02216.pdf).
+
+ Performs:
+ ```math
+ \mathbf{x}_i' = W [\mathbf{x}_i \,\|\, \square_{j \in \mathcal{N}(i)} \mathbf{x}_j]
+ ```
+
+ where the aggregation type is selected by `aggr`.
+
+ # Arguments
+
+ - `in`: The dimension of input features.
+ - `out`: The dimension of output features.
+ - `σ`: Activation function.
+ - `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
+ - `bias`: Add learnable bias.
+ - `init`: Weights' initializer.
+ """
+ struct SAGEConv{A<:AbstractMatrix, B} <: GNNLayer
+     weight::A
+     bias::B
+     σ
+     aggr
+ end
+
+ @functor SAGEConv
+
+ function SAGEConv(ch::Pair{Int,Int}, σ=identity; aggr=mean,
+                   init=glorot_uniform, bias::Bool=true)
+     in, out = ch
+     W = init(out, 2*in)
+     b = bias ? Flux.create_bias(W, true, out) : false
+     SAGEConv(W, b, σ, aggr)
+ end
+
+ compute_message(l::SAGEConv, x_i, x_j, e_ij) = x_j
+ update_node(l::SAGEConv, m, x) = l.σ.(l.weight * vcat(x, m) .+ l.bias)
+
+ function (l::SAGEConv)(g::GNNGraph, x::AbstractMatrix)
+     check_num_nodes(g, x)
+     x, _ = propagate(l, g, l.aggr, x)
+     x
+ end
+ function Base.show(io::IO, l::SAGEConv)
+     out_channel, in_channel = size(l.weight)   # weight is out × 2in
+     print(io, "SAGEConv(", in_channel ÷ 2, " => ", out_channel)
+     l.σ == identity || print(io, ", ", l.σ)
+     print(io, ", aggr=", l.aggr)
+     print(io, ")")
+ end
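A hedged smoke test for the new layer (again assuming `rand_graph` is available; sizes are made up):

```julia
using GraphNeuralNetworks, Flux
using Statistics: mean

g = rand_graph(6, 14)                          # 6 nodes, 14 edges (hypothetical)
l = SAGEConv(8 => 16, relu; aggr=mean)
x = rand(Float32, 8, 6)
size(l(g, x))                                  # expected (16, 6)
```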