
Commit 4587f7f

propagate redesign

1 parent 40154c7 commit 4587f7f

File tree: 5 files changed (+54 −48 lines)


Project.toml

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
 LearnBase = "7f8f8fb0-2700-5f03-b4bd-41f8cfc144b6"
 LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"

src/deprecations.jl

Lines changed: 4 additions & 0 deletions
@@ -4,3 +4,7 @@
 
 # TO Deprecate
 # x, _ = propagate(l, g, l.aggr, x, e)
+
+# # TODO deprecate
+# propagate(l, g::GNNGraph, aggr, x, e=nothing) = propagate(l, g, aggr; x, e)
+

src/layers/conv.jl

Lines changed: 19 additions & 31 deletions
@@ -39,26 +39,22 @@ function GCNConv(ch::Pair{Int,Int}, σ=identity;
     GCNConv(W, b, σ, add_self_loops)
 end
 
-## Matrix operations are more performant,
-## but cannot compute the normalized adjacency of sparse cuda matrices yet,
-## therefore fallback to message passing framework on gpu for the time being
-
 function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix{T}) where T
-    Ã = normalized_adjacency(g, T; dir=:out, l.add_self_loops)
-    l.σ.(l.weight * x * Ã .+ l.bias)
-end
-
-compute_message(l::GCNConv, xi, xj, eij) = xj
-
-function (l::GCNConv)(g::GNNGraph, x::CuMatrix{T}) where T
     if l.add_self_loops
         g = add_self_loops(g)
     end
+    Dout, Din = size(l.weight)
+    if Dout < Din
+        x = l.weight * x
+    end
     c = 1 ./ sqrt.(degree(g, T, dir=:in))
     x = x .* c'
-    x = propagate(l, g, +, xj=x)
+    x = propagate(copyxj, g, +, xj=x)
     x = x .* c'
-    return l.σ.(l.weight * x .+ l.bias)
+    if Dout >= Din
+        x = l.weight * x
+    end
+    return l.σ.(x .+ l.bias)
 end
 
 function Base.show(io::IO, l::GCNConv)
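
Reviewer note (not part of the diff): the new `Dout < Din` branch exploits the fact that propagation cost scales linearly with the feature dimension, so the dense projection is applied on whichever side of the message passing leaves fewer channels flowing through the graph. A minimal standalone sketch of the ordering logic, with made-up sizes:

W = rand(Float32, 16, 64)    # layer weight, Dout = 16, Din = 64
x = rand(Float32, 64, 1000)  # 64 features for 1000 nodes

Dout, Din = size(W)
if Dout < Din
    x = W * x                # project early: only 16×1000 values traverse the graph
end
# ... message passing runs here on min(Dout, Din) × N features ...
if Dout >= Din
    x = W * x                # project late when the input side is already smaller
end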
@@ -180,11 +176,9 @@ function GraphConv(ch::Pair{Int,Int}, σ=identity; aggr=+,
     GraphConv(W1, W2, b, σ, aggr)
 end
 
-compute_message(l::GraphConv, x_i, x_j, e_ij) = x_j
-
 function (l::GraphConv)(g::GNNGraph, x::AbstractMatrix)
     check_num_nodes(g, x)
-    m = propagate(l, g, l.aggr, xj=x)
+    m = propagate(copyxj, g, l.aggr, xj=x)
     x = l.σ.(l.weight1 * x .+ l.weight2 * m .+ l.bias)
     return x
 end
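
Reviewer note: the per-layer trivial methods like `compute_message(l::GraphConv, x_i, x_j, e_ij) = x_j` removed in this file are all replaced by the shared `copyxj` function, which is what makes the fused `propagate(::typeof(copyxj), g, ::typeof(+), ...)` specialization in src/msgpass.jl reachable. A hedged usage sketch on a tiny graph (assuming this package's `GNNGraph` constructor from source/target vectors):

using GraphNeuralNetworks

g = GNNGraph([1, 2, 3], [2, 3, 1])  # directed triangle
x = rand(Float32, 4, 3)             # 4 features per node
m = propagate(copyxj, g, +, xj=x)   # 4×3 sums of in-neighbor features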
@@ -272,7 +266,7 @@ function (l::GATConv)(g::GNNGraph, x::AbstractMatrix)
     Wx = l.weight * x
     Wx = reshape(Wx, chout, heads, :)              # chout × nheads × nnodes
 
-    m̄ = propagate(l, g, +; x=Wx)                   ## chout × nheads × nnodes
+    m̄ = propagate(l, g, +; xi=Wx, xj=Wx)           ## chout × nheads × nnodes
     x = m̄.m ./ m̄.α
 
     if !l.concat
@@ -330,8 +324,6 @@ function GatedGraphConv(out_ch::Int, num_layers::Int;
     GatedGraphConv(w, gru, out_ch, num_layers, aggr)
 end
 
-compute_message(l::GatedGraphConv, x_i, x_j, e_ij) = x_j
-
 # remove after https://github.com/JuliaDiff/ChainRules.jl/pull/521
 @non_differentiable fill!(x...)

@@ -345,7 +337,7 @@ function (l::GatedGraphConv)(g::GNNGraph, H::AbstractMatrix{S}) where {S<:Real}
     end
     for i = 1:l.num_layers
         M = view(l.weight, :, :, i) * H
-        M = propagate(l, g, l.aggr; xj=M)
+        M = propagate(copyxj, g, l.aggr; xj=M)
         H, _ = l.gru(H, M)
     end
     H
@@ -387,8 +379,8 @@ EdgeConv(nn; aggr=max) = EdgeConv(nn, aggr)
 compute_message(l::EdgeConv, x_i, x_j, e_ij) = l.nn(vcat(x_i, x_j .- x_i))
 
 function (l::EdgeConv)(g::GNNGraph, x::AbstractMatrix)
-    check_num_nodes(g, X)
-    x = propagate(l, g, l.aggr; x)
+    check_num_nodes(g, x)
+    x = propagate(l, g, l.aggr, xi=x, xj=x)
     return x
 end

@@ -426,11 +418,9 @@ Flux.trainable(l::GINConv) = (l.nn,)
 
 GINConv(nn, ϵ; aggr=+) = GINConv(nn, ϵ, aggr)
 
-compute_message(l::GINConv, x_i, x_j, e_ij) = x_j
-
 function (l::GINConv)(g::GNNGraph, x::AbstractMatrix)
     check_num_nodes(g, x)
-    m = propagate(l, g, l.aggr, xj=x)
+    m = propagate(copyxj, g, l.aggr, xj=x)
     l.nn((1 + ofeltype(x, l.ϵ)) * x + m)
 end
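
For reference (reviewer annotation), the unchanged `l.nn((1 + ofeltype(x, l.ϵ)) * x + m)` line is the GIN update rule, x′ᵢ = nn((1 + ϵ)·xᵢ + Σⱼ∈N(i) xⱼ), with `m` now produced by the `copyxj` path. A toy evaluation with stand-in values (`nn` and `ϵ` are hypothetical here):

nn = x -> 2 .* x               # stand-in for the layer's MLP
ϵ  = 0.1f0
xi = Float32[1, 2]             # one node's features
m  = Float32[3, 4]             # aggregated neighbor sum for that node
out = nn((1 + ϵ) .* xi .+ m)   # Float32[8.2, 12.4]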

@@ -549,11 +539,9 @@ function SAGEConv(ch::Pair{Int,Int}, σ=identity; aggr=mean,
     SAGEConv(W, b, σ, aggr)
 end
 
-compute_message(l::SAGEConv, x_i, x_j, e_ij) = x_j
-
 function (l::SAGEConv)(g::GNNGraph, x::AbstractMatrix)
     check_num_nodes(g, x)
-    m = propagate(l, g, l.aggr, xj=x)
+    m = propagate(copyxj, g, l.aggr, xj=x)
     x = l.σ.(l.weight * vcat(x, m) .+ l.bias)
     return x
 end
@@ -613,9 +601,9 @@ function ResGatedGraphConv(ch::Pair{Int,Int}, σ=identity;
     return ResGatedGraphConv(A, B, U, V, b, σ)
 end
 
-function compute_message(l::ResGatedGraphConv, di, dj)
-    η = sigmoid.(di.Ax .+ dj.Bx)
-    return η .* dj.Vx
+function compute_message(l::ResGatedGraphConv, xi, xj, e)
+    η = sigmoid.(xi.Ax .+ xj.Bx)
+    η .* xj.Vx
 end
 
 function (l::ResGatedGraphConv)(g::GNNGraph, x::AbstractMatrix)
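
Reviewer note: `compute_message` now uniformly takes `(layer, xi, xj, e)`, even when edge features go unused, matching the `apply_edges` closure in src/msgpass.jl. A minimal sketch of a custom layer under the new signature (`MyConv` is hypothetical):

struct MyConv <: GNNLayer
    weight::Matrix{Float32}
end

# Always accept (xi, xj, e); ignore e when the layer carries no edge features.
compute_message(l::MyConv, xi, xj, e) = l.weight * (xj .- xi)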

src/msgpass.jl

Lines changed: 29 additions & 16 deletions
@@ -53,21 +53,15 @@ See also [`compute_message`](@ref) and [`update_node`](@ref).
 """
 function propagate end
 
+propagate(l, g::GNNGraph, aggr; xi=nothing, xj=nothing, e=nothing) =
+    propagate(l, g, aggr, xi, xj, e)
 
-function propagate(l, g::GNNGraph, aggr; x=nothing, xi=nothing, xj=nothing, e=nothing)
-    if !isnothing(x)
-        @assert isnothing(xi)
-        @assert isnothing(xj)
-        xi, xj = x, x
-    end
+function propagate(l, g::GNNGraph, aggr, xi, xj, e)
     m = apply_edges(l, g, xi, xj, e)
     m̄ = aggregate_neighbors(g, aggr, m)
     return m̄
 end
 
-# TODO deprecate
-propagate(l, g::GNNGraph, aggr, x, e=nothing) = propagate(l, g, aggr; x, e)
-
 ## Step 1.
 
 """
@@ -106,8 +100,11 @@ _gather(x::Tuple, i) = map(x -> _gather(x, i), x)
 _gather(x::AbstractArray, i) = NNlib.gather(x, i)
 _gather(x::Nothing, i) = nothing
 
+apply_edges(l, g::GNNGraph; xi=nothing, xj=nothing, e=nothing) =
+    apply_edges(l, g, xi, xj, e)
+
 apply_edges(l::GNNLayer, g::GNNGraph, xi, xj, e) =
-    apply_edges((a...) -> compute_message(l, a...), g::GNNGraph, xi, xj, e)
+    apply_edges((xi, xj, e) -> compute_message(l, xi, xj, e), g, xi, xj, e)
 
 function apply_edges(f, g::GNNGraph, xi, xj, e)
     s, t = edge_index(g)
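
The kwarg shim added here also makes `apply_edges` convenient with plain anonymous functions; a hedged usage sketch (assuming a `GNNGraph` `g` and node feature matrix `x` are in scope):

# per-edge differences of endpoint features; e is accepted but unused
edge_feats = apply_edges((xi, xj, e) -> xj .- xi, g, xi=x, xj=x)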
@@ -120,16 +117,32 @@ end
 
 ## Step 2
 
-_scatter(aggr, e::NamedTuple, t) = map(e -> _scatter(aggr, e, t), e)
-_scatter(aggr, e::Tuple, t) = map(e -> _scatter(aggr, e, t), e)
-_scatter(aggr, e::AbstractArray, t) = NNlib.scatter(aggr, e, t)
-_scatter(aggr, e::Nothing, t) = nothing
+_scatter(aggr, m::NamedTuple, t) = map(m -> _scatter(aggr, m, t), m)
+_scatter(aggr, m::Tuple, t) = map(m -> _scatter(aggr, m, t), m)
+_scatter(aggr, m::AbstractArray, t) = NNlib.scatter(aggr, m, t)
 
-function aggregate_neighbors(g::GNNGraph, aggr, e)
+function aggregate_neighbors(g::GNNGraph, aggr, m)
     s, t = edge_index(g)
-    return _scatter(aggr, e, t)
+    return _scatter(aggr, m, t)
 end
 
 aggregate_neighbors(g::GNNGraph, aggr::Nothing, e) = nothing
 
 ### end steps ###
+
+
+
+### SPECIALIZATIONS OF PROPAGATE ###
+copyxi(xi, xj, e) = xi
+copyxj(xi, xj, e) = xj
+ximulxj(xi, xj, e) = xi .* xj
+xiaddxj(xi, xj, e) = xi .+ xj
+
+function propagate(::typeof(copyxj), g::GNNGraph, ::typeof(+), xi, xj, e)
+    A = adjacency_matrix(g)
+    return xj * A
+end
+
+# TODO divide by degree
+# propagate(::typeof(copyxj), g::GNNGraph, ::typeof(mean), xi, xj, e)
+
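
Reviewer note on the `copyxj`/`+` specialization above: gathering `xj` along source nodes and scatter-summing onto targets yields, for node i, Σⱼ:ⱼ→ᵢ xⱼ, which is exactly right-multiplication of the column-major feature matrix by the adjacency matrix. A hedged sanity check on a tiny graph (assuming this package's API at this commit):

using GraphNeuralNetworks, NNlib

s, t = [1, 2, 3], [2, 3, 1]     # directed triangle
g = GNNGraph(s, t)
x = rand(Float32, 4, 3)

m_scatter = NNlib.scatter(+, NNlib.gather(x, s), t)  # generic gather/scatter path
m_fused   = x * adjacency_matrix(g)                  # fused path
m_scatter ≈ m_fused                                  # expected: true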

test/examples/node_classification_cora.jl

Lines changed: 1 addition & 1 deletion
@@ -93,7 +93,7 @@ for (layer, Layer) in [
     ]
 
     @show layer
-    @time train_res, test_res = train(Layer, verbose=false)
+    @time train_res, test_res = train(Layer, verbose=true)
     @test train_res.acc > 95
     @test test_res.acc > 70
 end
