
Commit 0273b25

add Parallel support in GNNChain
1 parent 3585cca commit 0273b25

File tree

10 files changed

+260 -140 lines changed


Project.toml

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ authors = ["Carlo Lucibello and contributors"]
 version = "0.1.1"
 
 [deps]
+Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"

docs/src/models.md

Lines changed: 17 additions & 2 deletions
@@ -78,11 +78,26 @@ X = randn(Float32, din, 10)
 model = GNNChain(GCNConv(din => d),
                  BatchNorm(d),
                  x -> relu.(x),
-                 GraphConv(d => d, relu),
+                 GCNConv(d => d, relu),
                  Dropout(0.5),
                  Dense(d, dout))
 
-y = model(g, X)
+y = model(g, X)  # output size: (dout, g.num_nodes)
 ```
 
 The `GNNChain` only propagates the graph and the node features. More complex scenarios, e.g. when edge features are also updated, have to be handled using the explicit definition of the forward pass.
+
+A `GNNChain` conveniently propagates the graph into the branches created by the `Flux.Parallel` layer:
+
+```julia
+AddResidual(l) = Parallel(+, identity, l)
+
+model = GNNChain(AddResidual(ResGatedGraphConv(din => d, relu)),
+                 BatchNorm(d),
+                 AddResidual(ResGatedGraphConv(d => d, relu)),
+                 BatchNorm(d),
+                 GlobalPooling(mean),
+                 Dense(d, dout))
+
+y = model(g, X)  # output size: (dout, g.num_graphs)
+```
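For orientation, here is a self-contained sketch of running the residual model from the new docs end to end. It is only a sketch under stated assumptions: the graph construction via `random_regular_graph` (LightGraphs), the dimensions `din, d, dout, n`, and the single-graph setup are illustrative and not part of the commit; `GlobalPooling(mean)` is taken verbatim from the docs snippet above.

```julia
# Sketch, assuming GraphNeuralNetworks.jl at this commit, Flux, LightGraphs, Statistics.
using GraphNeuralNetworks, Flux, LightGraphs, Statistics

din, d, dout, n = 3, 4, 2, 10
g = GNNGraph(random_regular_graph(n, 4))    # a single graph, so g.num_graphs == 1
X = randn(Float32, din, n)                  # node features, one column per node

AddResidual(l) = Parallel(+, identity, l)   # skip connection around a GNN layer

model = GNNChain(AddResidual(ResGatedGraphConv(din => d, relu)),
                 BatchNorm(d),
                 AddResidual(ResGatedGraphConv(d => d, relu)),
                 BatchNorm(d),
                 GlobalPooling(mean),        # graph-level readout, as in the docs
                 Dense(d, dout))

y = model(g, X)
@assert size(y) == (dout, g.num_graphs)
```

Since every `Parallel` branch receives the same `(g, x)` pair, `AddResidual(l)` computes `x .+ l(g, x)`, i.e. a residual connection around the graph layer.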

src/GraphNeuralNetworks.jl

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@ export
     GINConv,
     GraphConv,
     NNConv,
+    ResGatedGraphConv,
     SAGEConv,
 
     # layers/pool

src/layers/basic.jl

Lines changed: 9 additions & 0 deletions
@@ -63,6 +63,15 @@ Flux.functor(::Type{<:GNNChain}, c) = c.layers, ls -> GNNChain(ls...)
 applylayer(l, g::GNNGraph, x) = l(x)
 applylayer(l::GNNLayer, g::GNNGraph, x) = l(g, x)
 
+# Handle Flux.Parallel
+applylayer(l::Parallel, g::GNNGraph, x::AbstractArray) = mapreduce(f -> applylayer(f, g, x), l.connection, l.layers)
+applylayer(l::Parallel, g::GNNGraph, xs::Vararg{<:AbstractArray}) = mapreduce((f, x) -> applylayer(f, g, x), l.connection, l.layers, xs)
+applylayer(l::Parallel, g::GNNGraph, xs::Tuple) = applylayer(l, g, xs...)
+
 applychain(::Tuple{}, g::GNNGraph, x) = x
 applychain(fs::Tuple, g::GNNGraph, x) = applychain(tail(fs), g, applylayer(first(fs), g, x))
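The dispatch added above forwards the same graph `g` to every branch of a `Flux.Parallel` layer and folds the branch outputs with `l.connection`. A rough standalone illustration of that `mapreduce` pattern (plain Flux, no graphs involved; `ToyGraph` and `apply_branch` are made-up stand-ins for `GNNGraph` and `applylayer`):

```julia
using Flux

struct ToyGraph end                     # stand-in for GNNGraph
apply_branch(f, g, x) = f(x)            # stand-in for applylayer(f, g, x)

p = Parallel(+, x -> 2 .* x, x -> x .+ 1)
g, x = ToyGraph(), Float32[1, 2, 3]

# the same fold the new method performs: apply each branch, combine with p.connection
y = mapreduce(f -> apply_branch(f, g, x), p.connection, p.layers)
@assert y == (2 .* x) .+ (x .+ 1)
```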

src/layers/conv.jl

Lines changed: 75 additions & 1 deletion
@@ -224,7 +224,7 @@ with ``z_i`` a normalization factor.
 
 - `in`: The dimension of input features.
 - `out`: The dimension of output features.
-- `bias::Bool`: Keyword argument, whether to learn the additive bias.
+- `bias`: Learn the additive bias if true.
 - `heads`: Number of attention heads.
 - `concat`: Concatenate layer output or not. If not, layer output is averaged over the heads.
 - `negative_slope`: The parameter of LeakyReLU.
@@ -572,3 +572,77 @@ function Base.show(io::IO, l::SAGEConv)
     print(io, ", aggr=", l.aggr)
     print(io, ")")
 end
+
+
+@doc raw"""
+    ResGatedGraphConv(in => out, act=identity; init=glorot_uniform, bias=true)
+
+The residual gated graph convolutional operator from the [Residual Gated Graph ConvNets](https://arxiv.org/abs/1711.07553) paper.
+
+The layer's forward pass is given by
+
+```math
+\mathbf{x}_i' = act\big(U\mathbf{x}_i + \sum_{j \in N(i)} \eta_{ij} V \mathbf{x}_j\big),
+```
+where the edge gates ``\eta_{ij}`` are given by
+
+```math
+\eta_{ij} = sigmoid(A\mathbf{x}_i + B\mathbf{x}_j).
+```
+
+# Arguments
+
+- `in`: The dimension of input features.
+- `out`: The dimension of output features.
+- `act`: Activation function.
+- `init`: Weight matrices' initializing function.
+- `bias`: Learn an additive bias if true.
+"""
+struct ResGatedGraphConv <: GNNLayer
+    A
+    B
+    U
+    V
+    bias
+    σ
+end
+
+@functor ResGatedGraphConv
+
+function ResGatedGraphConv(ch::Pair{Int,Int}, σ=identity;
+                           init=glorot_uniform, bias::Bool=true)
+    in, out = ch
+    A = init(out, in)
+    B = init(out, in)
+    U = init(out, in)
+    V = init(out, in)
+    b = bias ? Flux.create_bias(A, true, out) : false
+    return ResGatedGraphConv(A, B, U, V, b, σ)
+end
+
+function compute_message(l::ResGatedGraphConv, di, dj)
+    η = sigmoid.(di.Ax .+ dj.Bx)
+    return η .* dj.Vx
+end
+
+update_node(l::ResGatedGraphConv, m, x) = m
+
+function (l::ResGatedGraphConv)(g::GNNGraph, x::AbstractMatrix)
+    check_num_nodes(g, x)
+
+    Ax = l.A * x
+    Bx = l.B * x
+    Vx = l.V * x
+
+    m, _ = propagate(l, g, +, (; Ax, Bx, Vx))
+
+    return l.σ.(l.U*x .+ m .+ l.bias)
+end
+
+
+function Base.show(io::IO, l::ResGatedGraphConv)
+    out_channel, in_channel = size(l.A)
+    print(io, "ResGatedGraphConv(", in_channel, "=>", out_channel)
+    l.σ == identity || print(io, ", ", l.σ)
+    print(io, ")")
+end
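A brief usage sketch for the new layer; the graph construction (`random_regular_graph` from LightGraphs) and the dimensions are illustrative assumptions, not part of the commit:

```julia
using GraphNeuralNetworks, Flux, LightGraphs

n, din, dout = 10, 3, 4
g = GNNGraph(random_regular_graph(n, 4))
x = randn(Float32, din, n)

l = ResGatedGraphConv(din => dout, relu)
y = l(g, x)                     # gated message passing with a residual-style update
@assert size(y) == (dout, n)

# gradients flow through the four projections A, B, U, V and the bias
grads = gradient(() -> sum(l(g, x)), Flux.params(l))
```

Note that `A`, `B`, `U`, and `V` are four separate `out × in` matrices, so the layer has roughly four times the parameters of a plain graph convolution with the same dimensions.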

test/examples/node_classification_cora.jl

Lines changed: 3 additions & 2 deletions
@@ -70,8 +70,8 @@ function train(Layer; verbose=false, kws...)
             ŷ = model(g, X)
             logitcrossentropy(ŷ[:,train_ids], ytrain)
         end
-        verbose && report(epoch)
         Flux.Optimise.update!(opt, ps, gs)
+        verbose && report(epoch)
     end
 
     train_res = eval_loss_accuracy(X, y, train_ids, model, g)
@@ -87,11 +87,12 @@ for Layer in [
         (nin, nout) -> GATConv(nin => nout÷2, relu, heads=2),
         (nin, nout) -> GINConv(Dense(nin, nout, relu)),
         (nin, nout) -> ChebConv(nin => nout, 3),
+        (nin, nout) -> ResGatedGraphConv(nin => nout, relu),
         # (nin, nout) -> NNConv(nin => nout), # needs edge features
         # (nin, nout) -> GatedGraphConv(nout, 2), # needs nin = nout
         # (nin, nout) -> EdgeConv(Dense(2nin, nout, relu)), # Fits the training set but does not generalize well
     ]
-    train_res, test_res = train(Layer, verbose=true)
+    train_res, test_res = train(Layer, verbose=false)
     # @show Layer(2,2) train_res, test_res
     @test train_res.acc > 95
     @test test_res.acc > 70

test/layers/basic.jl

Lines changed: 20 additions & 17 deletions
@@ -2,32 +2,35 @@
 @testset "GNNChain" begin
     n, din, d, dout = 10, 3, 4, 2
 
-    g = GNNGraph(random_regular_graph(n, 4), graph_type=GRAPH_T)
+    g = GNNGraph(random_regular_graph(n, 4),
+                 graph_type=GRAPH_T,
+                 ndata=randn(Float32, din, n))
 
     gnn = GNNChain(GCNConv(din => d),
                    BatchNorm(d),
-                   x -> relu.(x),
-                   GraphConv(d => d, relu),
+                   x -> tanh.(x),
+                   GraphConv(d => d, tanh),
                    Dropout(0.5),
                    Dense(d, dout))
+
+    testmode!(gnn)
 
-    X = randn(Float32, din, n)
+    test_layer(gnn, g, rtol=1e-5) # exclude BN buffers
 
-    y = gnn(g, X)
-
-    @test y isa Matrix{Float32}
-    @test size(y) == (dout, n)
 
-    @test length(params(gnn)) == 9
-
-    gs = gradient(x -> sum(gnn(g, x)), X)[1]
-    @test gs isa Matrix{Float32}
-    @test size(gs) == size(X)
+    @testset "Parallel" begin
+        AddResidual(l) = Parallel(+, identity, l)
+
+        gnn = GNNChain(AddResidual(ResGatedGraphConv(din => d, tanh)),
+                       BatchNorm(d),
+                       AddResidual(ResGatedGraphConv(d => d, tanh)),
+                       BatchNorm(d),
+                       Dense(d, dout))
 
-    gs = gradient(() -> sum(gnn(g, X)), Flux.params(gnn))
-    for p in Flux.params(gnn)
-        @test eltype(gs[p]) == Float32
-        @test size(gs[p]) == size(p)
+        testmode!(gnn)
+
+        test_layer(gnn, g, rtol=1e-5, verbose=true,
+                   exclude_grad_fields=[:μ, :σ², :ϵ]) # exclude BN buffers
     end
 end
 end
