@@ -49,15 +49,14 @@ function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix{T}) where T
 end
 
 compute_message(l::GCNConv, xi, xj, eij) = xj
-update_node(l::GCNConv, m, x) = m
 
 function (l::GCNConv)(g::GNNGraph, x::CuMatrix{T}) where T
     if l.add_self_loops
         g = add_self_loops(g)
     end
     c = 1 ./ sqrt.(degree(g, T, dir=:in))
     x = x .* c'
-    x, _ = propagate(l, g, +, x)
+    x = propagate(l, g, +, xj=x)
     x = x .* c'
     return l.σ.(l.weight * x .+ l.bias)
 end
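For context, a minimal usage sketch of the layer after this refactor. It assumes GraphNeuralNetworks.jl and Flux are loaded; `rand_graph` and the `3 => 8` sizes are illustrative, not part of the diff:

```julia
using GraphNeuralNetworks, Flux

g = rand_graph(10, 30)       # assumed helper: random graph with 10 nodes, 30 edges
x = rand(Float32, 3, 10)     # node features as columns: 3 features × 10 nodes
l = GCNConv(3 => 8, relu)    # degree-normalized convolution with ReLU
y = l(g, x)                  # 8 × 10 node embeddings via the new propagate(...; xj=x)
```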
@@ -182,12 +181,12 @@ function GraphConv(ch::Pair{Int,Int}, σ=identity; aggr=+,
 end
 
 compute_message(l::GraphConv, x_i, x_j, e_ij) = x_j
-update_node(l::GraphConv, m, x) = l.σ.(l.weight1 * x .+ l.weight2 * m .+ l.bias)
 
 function (l::GraphConv)(g::GNNGraph, x::AbstractMatrix)
     check_num_nodes(g, x)
-    x, _ = propagate(l, g, l.aggr, x)
-    x
+    m = propagate(l, g, l.aggr, xj=x)
+    x = l.σ.(l.weight1 * x .+ l.weight2 * m .+ l.bias)
+    return x
 end
 
 function Base.show(io::IO, l::GraphConv)
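The same pattern, sketched end to end for GraphConv (graph and sizes are assumptions): the aggregated message `m` now comes back from `propagate`, and the affine update that used to live in `update_node` is applied inline.

```julia
using GraphNeuralNetworks, Flux

g = rand_graph(6, 12)                # assumed random graph
x = rand(Float32, 4, 6)
l = GraphConv(4 => 5, tanh; aggr=+)  # weight1 acts on x, weight2 on the aggregated m
y = l(g, x)
@assert size(y) == (5, 6)
```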
@@ -264,8 +263,6 @@ function compute_message(l::GATConv, Wxi, Wxj)
     return (α = α, m = α .* Wxj)
 end
 
-update_node(l::GATConv, d̄, x) = d̄.m ./ d̄.α
-
 function (l::GATConv)(g::GNNGraph, x::AbstractMatrix)
     check_num_nodes(g, x)
     g = add_self_loops(g)
@@ -275,7 +272,8 @@ function (l::GATConv)(g::GNNGraph, x::AbstractMatrix)
     Wx = l.weight * x
     Wx = reshape(Wx, chout, heads, :)    # chout × nheads × nnodes
 
-    x, _ = propagate(l, g, +, Wx)        ## chout × nheads × nnodes
+    d̄ = propagate(l, g, +; x=Wx)         ## chout × nheads × nnodes
+    x = d̄.m ./ d̄.α
 
     if !l.concat
         x = mean(x, dims=2)
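Why this works: `compute_message` returns the named tuple `(α = α, m = α .* Wxj)`, so aggregating with `+` sums the softmax numerator and denominator side by side, and a single division after `propagate` recovers the attention-weighted average. A pure-Julia sketch with made-up numbers:

```julia
# Two incoming edges with unnormalized attention weights and pre-weighted messages.
α1, α2 = 0.5f0, 1.5f0
m1, m2 = α1 .* Float32[1, 2], α2 .* Float32[3, 4]

# Field-wise sum, as propagate(l, g, +; x=Wx) produces for the (α, m) tuples.
d̄ = (α = α1 + α2, m = m1 .+ m2)
x = d̄.m ./ d̄.α    # attention-weighted average of the neighbor messages
```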
@@ -302,7 +300,7 @@ Gated graph convolution layer from [Gated Graph Sequence Neural Networks](https:
 
 Implements the recursion
 ```math
-\mathbf{h}^{(0)}_i = \mathbf{x}_i || \mathbf{0} \\
+\mathbf{h}^{(0)}_i = [\mathbf{x}_i || \mathbf{0}] \\
 \mathbf{h}^{(l)}_i = GRU(\mathbf{h}^{(l-1)}_i, \square_{j \in N(i)} W \mathbf{h}^{(l-1)}_j)
 ```
 
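The bracketed form makes the zero padding explicit: the input is extended with zero rows up to the hidden width. A plain-Julia sketch of that initial state (sizes assumed):

```julia
in_ch, out_ch, n = 3, 8, 5
X = rand(Float32, in_ch, n)

# h⁽⁰⁾ᵢ = [xᵢ || 0]: pad each column with zeros so H has out_ch rows
H = vcat(X, zeros(Float32, out_ch - in_ch, n))
@assert size(H) == (out_ch, n)
```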
@@ -325,17 +323,14 @@ end
 
 @functor GatedGraphConv
 
-
 function GatedGraphConv(out_ch::Int, num_layers::Int;
                         aggr=+, init=glorot_uniform)
     w = init(out_ch, out_ch, num_layers)
     gru = GRUCell(out_ch, out_ch)
     GatedGraphConv(w, gru, out_ch, num_layers, aggr)
 end
 
-
 compute_message(l::GatedGraphConv, x_i, x_j, e_ij) = x_j
-update_node(l::GatedGraphConv, m, x) = m
 
 # remove after https://github.com/JuliaDiff/ChainRules.jl/pull/521
 @non_differentiable fill!(x...)
@@ -350,7 +345,7 @@ function (l::GatedGraphConv)(g::GNNGraph, H::AbstractMatrix{S}) where {S<:Real}
     end
     for i = 1:l.num_layers
         M = view(l.weight, :, :, i) * H
-        M, _ = propagate(l, g, l.aggr, M)
+        M = propagate(l, g, l.aggr; xj=M)
         H, _ = l.gru(H, M)
     end
     H
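A usage sketch of the layer whose inner loop changed above (graph and sizes are assumptions; the input width must not exceed `out_ch`):

```julia
using GraphNeuralNetworks, Flux

g = rand_graph(8, 20)
x = rand(Float32, 3, 8)    # 3 ≤ out_ch, so x gets zero-padded internally
l = GatedGraphConv(8, 2)   # out_ch = 8, num_layers = 2
h = l(g, x)                # 8 × 8 after two propagate + GRU steps
```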
@@ -391,12 +386,10 @@ EdgeConv(nn; aggr=max) = EdgeConv(nn, aggr)
 
 compute_message(l::EdgeConv, x_i, x_j, e_ij) = l.nn(vcat(x_i, x_j .- x_i))
 
-update_node(l::EdgeConv, m, x) = m
-
-function (l::EdgeConv)(g::GNNGraph, X::AbstractMatrix)
-    check_num_nodes(g, X)
-    X, _ = propagate(l, g, l.aggr, X)
-    X
+function (l::EdgeConv)(g::GNNGraph, x::AbstractMatrix)
+    check_num_nodes(g, x)
+    x = propagate(l, g, l.aggr; x)
+    return x
 end
 
 function Base.show(io::IO, l::EdgeConv)
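Since `compute_message` feeds `vcat(x_i, x_j .- x_i)` to the inner network, that network must accept twice the node feature width. A hedged usage sketch (sizes assumed):

```julia
using GraphNeuralNetworks, Flux

g = rand_graph(6, 12)
x = rand(Float32, 4, 6)
l = EdgeConv(Dense(2 * 4, 16); aggr=max)  # nn sees [x_i; x_j - x_i], hence 8 inputs
y = l(g, x)                               # 16 × 6 after max-aggregation
```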
@@ -433,25 +426,21 @@ Flux.trainable(l::GINConv) = (l.nn,)
 
 GINConv(nn, ϵ; aggr=+) = GINConv(nn, ϵ, aggr)
 
-
 compute_message(l::GINConv, x_i, x_j, e_ij) = x_j
-update_node(l::GINConv, m, x) = l.nn((1 + ofeltype(x, l.ϵ)) * x + m)
 
-function (l::GINConv)(g::GNNGraph, X::AbstractMatrix)
-    check_num_nodes(g, X)
-    X, _ = propagate(l, g, l.aggr, X)
-    X
+function (l::GINConv)(g::GNNGraph, x::AbstractMatrix)
+    check_num_nodes(g, x)
+    m = propagate(l, g, l.aggr, xj=x)
+    l.nn((1 + ofeltype(x, l.ϵ)) * x + m)
 end
 
-
 function Base.show(io::IO, l::GINConv)
     print(io, "GINConv($(l.nn)")
     print(io, ", $(l.ϵ)")
     print(io, ")")
 end
 
 
-
 @doc raw"""
     NNConv(in => out, f, σ=identity; aggr=+, bias=true, init=glorot_uniform)
 
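Before the NNConv hunk below, a usage sketch of the refactored GINConv (the MLP and sizes are assumptions):

```julia
using GraphNeuralNetworks, Flux

g = rand_graph(6, 12)
x = rand(Float32, 4, 6)
nn = Chain(Dense(4, 8, relu), Dense(8, 8))  # assumed injective MLP, as in the GIN paper
l = GINConv(nn, 0.1f0)                      # ϵ re-weights the node's own features
y = l(g, x)                                 # nn((1 + ϵ) * x + summed neighbor messages)
```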
@@ -499,21 +488,17 @@ function NNConv(ch::Pair{Int,Int}, nn, σ=identity; aggr=+, bias=true, init=glor
 end
 
 function compute_message(l::NNConv, x_i, x_j, e_ij)
-    nin, nedges = size(x_i)
+    nin, nedges = size(x_j)
     W = reshape(l.nn(e_ij), (:, nin, nedges))
     x_j = reshape(x_j, (nin, 1, nedges))   # needed by batched_mul
     m = NNlib.batched_mul(W, x_j)
     return reshape(m, :, nedges)
 end
 
-function update_node(l::NNConv, m, x)
-    l.σ.(l.weight*x .+ m .+ l.bias)
-end
-
 function (l::NNConv)(g::GNNGraph, x::AbstractMatrix, e)
     check_num_nodes(g, x)
-    x, _ = propagate(l, g, l.aggr, x, e)
-    return x
+    m = propagate(l, g, l.aggr, xj=x, e=e)
+    return l.σ.(l.weight * x .+ m .+ l.bias)
 end
 
 (l::NNConv)(g::GNNGraph) = GNNGraph(g, ndata=l(g, node_features(g), edge_features(g)))
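A usage sketch of NNConv after the refactor: the inner network maps each edge feature vector to a flattened `out × in` weight matrix, so its output width must be `out * in` (edge dimension and sizes below are assumptions):

```julia
using GraphNeuralNetworks, Flux

g = rand_graph(6, 12)
x = rand(Float32, 3, 6)              # node features
e = rand(Float32, 2, g.num_edges)    # one 2-dim feature per edge
nn = Dense(2, 8 * 3)                 # edge feature ↦ flattened 8×3 weight matrix
l = NNConv(3 => 8, nn, relu; aggr=+)
y = l(g, x, e)                       # 8 × 6
```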
@@ -533,7 +518,7 @@ GraphSAGE convolution layer from paper [Inductive Representation Learning on Lar
 
 Performs:
 ```math
-\mathbf{x}_i' = W [\mathbf{x}_i \,\|\, \square_{j \in \mathcal{N}(i)} \mathbf{x}_j]
+\mathbf{x}_i' = W \cdot [\mathbf{x}_i \,\|\, \square_{j \in \mathcal{N}(i)} \mathbf{x}_j]
 ```
 
 where the aggregation type is selected by `aggr`.
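A plain-Julia rendering of the formula for a single node, with made-up values, to show why `W` needs `2 * in` columns:

```julia
xi  = Float32[1, 2]                   # features of node i
xjs = [Float32[3, 4], Float32[5, 6]]  # features of its neighbors
agg = sum(xjs) ./ length(xjs)         # □ = mean over N(i)
W   = rand(Float32, 3, 4)             # 4 = 2 * in columns for the concatenation
xi′ = W * vcat(xi, agg)               # W ⋅ [xᵢ || mean_j xⱼ]
```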
@@ -565,12 +550,12 @@ function SAGEConv(ch::Pair{Int,Int}, σ=identity; aggr=mean,
 end
 
 compute_message(l::SAGEConv, x_i, x_j, e_ij) = x_j
-update_node(l::SAGEConv, m, x) = l.σ.(l.weight * vcat(x, m) .+ l.bias)
 
 function (l::SAGEConv)(g::GNNGraph, x::AbstractMatrix)
     check_num_nodes(g, x)
-    x, _ = propagate(l, g, l.aggr, x)
-    x
+    m = propagate(l, g, l.aggr, xj=x)
+    x = l.σ.(l.weight * vcat(x, m) .+ l.bias)
+    return x
 end
 
 function Base.show(io::IO, l::SAGEConv)
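And a usage sketch of the refactored SAGEConv itself (graph and sizes assumed):

```julia
using GraphNeuralNetworks, Flux

g = rand_graph(6, 12)
x = rand(Float32, 4, 6)
l = SAGEConv(4 => 8; aggr=mean)  # weight multiplies vcat(x, m), i.e. 2 * 4 rows
y = l(g, x)                      # 8 × 6
```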
@@ -633,21 +618,18 @@ function compute_message(l::ResGatedGraphConv, di, dj)
     return η .* dj.Vx
 end
 
-update_node(l::ResGatedGraphConv, m, x) = m
-
 function (l::ResGatedGraphConv)(g::GNNGraph, x::AbstractMatrix)
     check_num_nodes(g, x)
 
     Ax = l.A * x
     Bx = l.B * x
     Vx = l.V * x
 
-    m, _ = propagate(l, g, +, (; Ax, Bx, Vx))
+    m = propagate(l, g, +, xi=(; Ax), xj=(; Bx, Vx))
 
     return l.σ.(l.U*x .+ m .+ l.bias)
 end
 
-
 function Base.show(io::IO, l::ResGatedGraphConv)
     out_channel, in_channel = size(l.A)
     print(io, "ResGatedGraphConv(", in_channel, " => ", out_channel)
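The split into `xi = (; Ax)` and `xj = (; Bx, Vx)` mirrors which projections belong to the target and the source node of each edge. A pure-Julia sketch of the per-edge gate from `compute_message` (values made up):

```julia
using Flux: σ    # logistic sigmoid

Axi = Float32[0.2, -0.1]    # A * x for the target node i
Bxj = Float32[0.4,  0.3]    # B * x for the source node j
Vxj = Float32[1.0,  2.0]    # V * x for the source node j

η = σ.(Axi .+ Bxj)          # edge gate in (0, 1)
m = η .* Vxj                # gated message, aggregated with + over edges
```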