@@ -39,26 +39,22 @@ function GCNConv(ch::Pair{Int,Int}, σ=identity;
39
39
GCNConv (W, b, σ, add_self_loops)
40
40
end
41
41
42
- # # Matrix operations are more performant,
43
- # # but cannot compute the normalized adjacency of sparse cuda matrices yet,
44
- # # therefore fallback to message passing framework on gpu for the time being
45
-
46
42
function (l:: GCNConv )(g:: GNNGraph , x:: AbstractMatrix{T} ) where T
47
- Ã = normalized_adjacency (g, T; dir= :out , l. add_self_loops)
48
- l. σ .(l. weight * x * Ã .+ l. bias)
49
- end
50
-
51
- compute_message (l:: GCNConv , xi, xj, eij) = xj
52
-
53
- function (l:: GCNConv )(g:: GNNGraph , x:: CuMatrix{T} ) where T
54
43
if l. add_self_loops
55
44
g = add_self_loops (g)
56
45
end
46
+ Dout, Din = size (l. weight)
47
+ if Dout < Din
48
+ x = l. weight * x
49
+ end
57
50
c = 1 ./ sqrt .(degree (g, T, dir= :in ))
58
51
x = x .* c'
59
- x = propagate (l , g, + , xj= x)
52
+ x = propagate (copyxj , g, + , xj= x)
60
53
x = x .* c'
61
- return l. σ .(l. weight * x .+ l. bias)
54
+ if Dout >= Din
55
+ x = l. weight * x
56
+ end
57
+ return l. σ .(x .+ l. bias)
62
58
end
63
59
64
60
function Base. show (io:: IO , l:: GCNConv )
@@ -180,11 +176,9 @@ function GraphConv(ch::Pair{Int,Int}, σ=identity; aggr=+,
180
176
GraphConv (W1, W2, b, σ, aggr)
181
177
end
182
178
183
- compute_message (l:: GraphConv , x_i, x_j, e_ij) = x_j
184
-
185
179
function (l:: GraphConv )(g:: GNNGraph , x:: AbstractMatrix )
186
180
check_num_nodes (g, x)
187
- m = propagate (l , g, l. aggr, xj= x)
181
+ m = propagate (copyxj , g, l. aggr, xj= x)
188
182
x = l. σ .(l. weight1 * x .+ l. weight2 * m .+ l. bias)
189
183
return x
190
184
end
@@ -272,7 +266,7 @@ function (l::GATConv)(g::GNNGraph, x::AbstractMatrix)
272
266
Wx = l. weight * x
273
267
Wx = reshape (Wx, chout, heads, :) # chout × nheads × nnodes
274
268
275
- d̄ = propagate (l, g, + ; x = Wx) # # chout × nheads × nnodes
269
+ d̄ = propagate (l, g, + ; xi = Wx, xj = Wx) # # chout × nheads × nnodes
276
270
x = d̄. m ./ d̄. α
277
271
278
272
if ! l. concat
@@ -330,8 +324,6 @@ function GatedGraphConv(out_ch::Int, num_layers::Int;
330
324
GatedGraphConv (w, gru, out_ch, num_layers, aggr)
331
325
end
332
326
333
- compute_message (l:: GatedGraphConv , x_i, x_j, e_ij) = x_j
334
-
335
327
# remove after https://github.com/JuliaDiff/ChainRules.jl/pull/521
336
328
@non_differentiable fill! (x... )
337
329
@@ -345,7 +337,7 @@ function (l::GatedGraphConv)(g::GNNGraph, H::AbstractMatrix{S}) where {S<:Real}
345
337
end
346
338
for i = 1 : l. num_layers
347
339
M = view (l. weight, :, :, i) * H
348
- M = propagate (l , g, l. aggr; xj= M)
340
+ M = propagate (copyxj , g, l. aggr; xj= M)
349
341
H, _ = l. gru (H, M)
350
342
end
351
343
H
@@ -387,8 +379,8 @@ EdgeConv(nn; aggr=max) = EdgeConv(nn, aggr)
387
379
compute_message (l:: EdgeConv , x_i, x_j, e_ij) = l. nn (vcat (x_i, x_j .- x_i))
388
380
389
381
function (l:: EdgeConv )(g:: GNNGraph , x:: AbstractMatrix )
390
- check_num_nodes (g, X )
391
- x = propagate (l, g, l. aggr; x)
382
+ check_num_nodes (g, x )
383
+ x = propagate (l, g, l. aggr, xi = x, xj = x)
392
384
return x
393
385
end
394
386
@@ -426,11 +418,9 @@ Flux.trainable(l::GINConv) = (l.nn,)
426
418
427
419
GINConv (nn, ϵ; aggr= + ) = GINConv (nn, ϵ, aggr)
428
420
429
- compute_message (l:: GINConv , x_i, x_j, e_ij) = x_j
430
-
431
421
function (l:: GINConv )(g:: GNNGraph , x:: AbstractMatrix )
432
422
check_num_nodes (g, x)
433
- m = propagate (l , g, l. aggr, xj= x)
423
+ m = propagate (copyxj , g, l. aggr, xj= x)
434
424
l. nn ((1 + ofeltype (x, l. ϵ)) * x + m)
435
425
end
436
426
@@ -549,11 +539,9 @@ function SAGEConv(ch::Pair{Int,Int}, σ=identity; aggr=mean,
549
539
SAGEConv (W, b, σ, aggr)
550
540
end
551
541
552
- compute_message (l:: SAGEConv , x_i, x_j, e_ij) = x_j
553
-
554
542
function (l:: SAGEConv )(g:: GNNGraph , x:: AbstractMatrix )
555
543
check_num_nodes (g, x)
556
- m = propagate (l , g, l. aggr, xj= x)
544
+ m = propagate (copyxj , g, l. aggr, xj= x)
557
545
x = l. σ .(l. weight * vcat (x, m) .+ l. bias)
558
546
return x
559
547
end
@@ -613,9 +601,9 @@ function ResGatedGraphConv(ch::Pair{Int,Int}, σ=identity;
613
601
return ResGatedGraphConv (A, B, U, V, b, σ)
614
602
end
615
603
616
- function compute_message (l:: ResGatedGraphConv , di, dj )
617
- η = sigmoid .(di . Ax .+ dj . Bx)
618
- return η .* dj . Vx
604
+ function compute_message (l:: ResGatedGraphConv , xi, xj, e )
605
+ η = sigmoid .(xi . Ax .+ xj . Bx)
606
+ η .* xj . Vx
619
607
end
620
608
621
609
function (l:: ResGatedGraphConv )(g:: GNNGraph , x:: AbstractMatrix )
0 commit comments