Commit 51cd1e9

Merge pull request #46 from CarloLucibello/cl/dev
implement ResGatedGraphConv and support Parallel in GNNChain
2 parents c514d27 + 1bfb633 commit 51cd1e9

11 files changed: +175 −43 lines

Project.toml

Lines changed: 2 additions & 0 deletions

@@ -4,6 +4,7 @@ authors = ["Carlo Lucibello and contributors"]
 version = "0.1.1"

 [deps]
+Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
@@ -20,6 +21,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

 [compat]
+Adapt = "3"
 CUDA = "3.3"
 ChainRulesCore = "1"
 DataStructures = "0.18"

docs/src/models.md

Lines changed: 17 additions & 2 deletions

@@ -78,11 +78,26 @@ X = randn(Float32, din, 10)
 model = GNNChain(GCNConv(din => d),
                  BatchNorm(d),
                  x -> relu.(x),
-                 GraphConv(d => d, relu),
+                 GCNConv(d => d, relu),
                  Dropout(0.5),
                  Dense(d, dout))

-y = model(g, X)
+y = model(g, X)  # output size: (dout, g.num_nodes)
 ```

 The `GNNChain` only propagates the graph and the node features. More complex scenarios, e.g. when edge features are also updated, have to be handled using an explicit definition of the forward pass.
+
+A `GNNChain` automatically propagates the graph into the branches created by a `Flux.Parallel` layer:
+
+```julia
+AddResidual(l) = Parallel(+, identity, l)  # implementing a skip/residual connection
+
+model = GNNChain(ResGatedGraphConv(din => d, relu),
+                 AddResidual(ResGatedGraphConv(d => d, relu)),
+                 AddResidual(ResGatedGraphConv(d => d, relu)),
+                 AddResidual(ResGatedGraphConv(d => d, relu)),
+                 GlobalPooling(mean),
+                 Dense(d, dout))
+
+y = model(g, X)  # output size: (dout, g.num_graphs)
+```
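For reference, the skip connection built with `AddResidual` above can be written out explicitly. The following is a minimal sketch (not part of the committed docs), assuming `g`, `X`, `din`, and `d` are defined as in the documentation example:

```julia
using Flux, GraphNeuralNetworks

l1 = ResGatedGraphConv(din => d, relu)
l2 = ResGatedGraphConv(d => d, relu)
skip = Parallel(+, identity, l2)      # same as AddResidual(l2)

h = l1(g, X)                          # size (d, g.num_nodes)
# Inside a GNNChain both branches of the Parallel receive the graph,
# so the skip stage computes identity(h) + l2(g, h):
y = GNNChain(l1, skip)(g, X)
y ≈ h .+ l2(g, h)                     # expected to hold
```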

src/GraphNeuralNetworks.jl

Lines changed: 2 additions & 0 deletions

@@ -46,6 +46,7 @@ export
 GINConv,
 GraphConv,
 NNConv,
+ResGatedGraphConv,
 SAGEConv,

 # layers/pool
@@ -62,5 +63,6 @@ include("msgpass.jl")
 include("layers/basic.jl")
 include("layers/conv.jl")
 include("layers/pool.jl")
+include("deprecations.jl")

 end

src/deprecations.jl

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+# Deprecated in v0.1
+
+@deprecate GINConv(nn; eps=0, aggr=+) GINConv(nn, eps; aggr)
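The `@deprecate` above keeps the old keyword-style constructor working while forwarding to the new positional form introduced in `src/layers/conv.jl`. A minimal sketch of the two call styles (the `Dense(4, 4)` inner network is just a placeholder):

```julia
using Flux, GraphNeuralNetworks, Statistics

nn = Dense(4, 4)

l_old = GINConv(nn; eps=0.1f0, aggr=mean)   # still works, emits a deprecation warning
l_new = GINConv(nn, 0.1f0; aggr=mean)       # equivalent new form
```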

src/layers/basic.jl

Lines changed: 6 additions & 0 deletions

@@ -63,6 +63,12 @@ Flux.functor(::Type{<:GNNChain}, c) = c.layers, ls -> GNNChain(ls...)
 applylayer(l, g::GNNGraph, x) = l(x)
 applylayer(l::GNNLayer, g::GNNGraph, x) = l(g, x)

+# Handle Flux.Parallel
+applylayer(l::Parallel, g::GNNGraph, x::AbstractArray) = mapreduce(f -> applylayer(f, g, x), l.connection, l.layers)
+applylayer(l::Parallel, g::GNNGraph, xs::Vararg{<:AbstractArray}) = mapreduce((f, x) -> applylayer(f, g, x), l.connection, l.layers, xs)
+applylayer(l::Parallel, g::GNNGraph, xs::Tuple) = applylayer(l, g, xs...)
+
 applychain(::Tuple{}, g::GNNGraph, x) = x
 applychain(fs::Tuple, g::GNNGraph, x) = applychain(tail(fs), g, applylayer(first(fs), g, x))
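These methods route the graph into every branch of a `Parallel` layer and combine the branch outputs with `l.connection`. A rough sketch of what the single-array method reduces to, assuming some `GNNGraph` `g` is already in scope (`applylayer` is internal, so it is qualified explicitly here):

```julia
using Flux, GraphNeuralNetworks
import GraphNeuralNetworks: applylayer    # internal, not exported

d = 4
conv = GCNConv(d => d)
x = randn(Float32, d, g.num_nodes)
par = Parallel(+, identity, conv)

# mapreduce(f -> applylayer(f, g, x), +, (identity, conv))
# folds the two branches into identity(x) + conv(g, x):
applylayer(par, g, x) ≈ x .+ conv(g, x)   # expected to hold
```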

src/layers/conv.jl

Lines changed: 92 additions & 10 deletions

@@ -224,7 +224,7 @@ with ``z_i`` a normalization factor.
 - `in`: The dimension of input features.
 - `out`: The dimension of output features.
-- `bias::Bool`: Keyword argument, whether to learn the additive bias.
+- `bias`: Learn the additive bias if true.
 - `heads`: Number of attention heads.
 - `concat`: Concatenate layer output or not. If not, layer output is averaged over the heads.
 - `negative_slope`: The parameter of LeakyReLU.
@@ -407,7 +407,7 @@ end

 @doc raw"""
-    GINConv(f; eps = 0f0)
+    GINConv(f, ϵ; aggr=+)

 Graph Isomorphism convolutional layer from paper [How Powerful are Graph Neural Networks?](https://arxiv.org/pdf/1810.00826.pdf)

@@ -420,30 +420,38 @@ where ``f_\Theta`` typically denotes a learnable function, e.g. a linear layer o
 # Arguments

 - `f`: A (possibly learnable) function acting on node features.
-- `eps`: Weighting factor.
+- `ϵ`: Weighting factor.
 """
 struct GINConv{R<:Real} <: GNNLayer
     nn
-    eps::R
+    ϵ::R
+    aggr
 end

 @functor GINConv
-Flux.trainable(l::GINConv) = (nn=l.nn,)
+Flux.trainable(l::GINConv) = (l.nn,)
+
+GINConv(nn, ϵ; aggr=+) = GINConv(nn, ϵ, aggr)

-function GINConv(nn; eps=0f0)
-    GINConv(nn, eps)
-end

 compute_message(l::GINConv, x_i, x_j, e_ij) = x_j
-update_node(l::GINConv, m, x) = l.nn((1 + l.eps) * x + m)
+update_node(l::GINConv, m, x) = l.nn((1 + ofeltype(x, l.ϵ)) * x + m)

 function (l::GINConv)(g::GNNGraph, X::AbstractMatrix)
     check_num_nodes(g, X)
-    X, _ = propagate(l, g, +, X)
+    X, _ = propagate(l, g, l.aggr, X)
     X
 end

+function Base.show(io::IO, l::GINConv)
+    print(io, "GINConv($(l.nn)")
+    print(io, ", $(l.ϵ)")
+    print(io, ")")
+end
+
+
 @doc raw"""
     NNConv(in => out, f, σ=identity; aggr=+, bias=true, init=glorot_uniform)

@@ -572,3 +580,77 @@ function Base.show(io::IO, l::SAGEConv)
     print(io, ", aggr=", l.aggr)
     print(io, ")")
 end
+
+
+@doc raw"""
+    ResGatedGraphConv(in => out, act=identity; init=glorot_uniform, bias=true)
+
+The residual gated graph convolutional operator from the [Residual Gated Graph ConvNets](https://arxiv.org/abs/1711.07553) paper.
+
+The layer's forward pass is given by
+
+```math
+\mathbf{x}_i' = act\big(U\mathbf{x}_i + \sum_{j \in N(i)} \eta_{ij} V \mathbf{x}_j\big),
+```
+
+where the edge gates ``\eta_{ij}`` are given by
+
+```math
+\eta_{ij} = sigmoid(A\mathbf{x}_i + B\mathbf{x}_j).
+```
+
+# Arguments
+
+- `in`: The dimension of input features.
+- `out`: The dimension of output features.
+- `act`: Activation function.
+- `init`: Weight matrices' initializing function.
+- `bias`: Learn an additive bias if true.
+"""
+struct ResGatedGraphConv <: GNNLayer
+    A
+    B
+    U
+    V
+    bias
+    σ
+end
+
+@functor ResGatedGraphConv
+
+function ResGatedGraphConv(ch::Pair{Int,Int}, σ=identity;
+                           init=glorot_uniform, bias::Bool=true)
+    in, out = ch
+    A = init(out, in)
+    B = init(out, in)
+    U = init(out, in)
+    V = init(out, in)
+    b = bias ? Flux.create_bias(A, true, out) : false
+    return ResGatedGraphConv(A, B, U, V, b, σ)
+end
+
+function compute_message(l::ResGatedGraphConv, di, dj)
+    η = sigmoid.(di.Ax .+ dj.Bx)
+    return η .* dj.Vx
+end
+
+update_node(l::ResGatedGraphConv, m, x) = m
+
+function (l::ResGatedGraphConv)(g::GNNGraph, x::AbstractMatrix)
+    check_num_nodes(g, x)
+
+    Ax = l.A * x
+    Bx = l.B * x
+    Vx = l.V * x
+
+    m, _ = propagate(l, g, +, (; Ax, Bx, Vx))
+
+    return l.σ.(l.U*x .+ m .+ l.bias)
+end
+
+function Base.show(io::IO, l::ResGatedGraphConv)
+    out_channel, in_channel = size(l.A)
+    print(io, "ResGatedGraphConv(", in_channel, "=>", out_channel)
+    l.σ == identity || print(io, ", ", l.σ)
+    print(io, ")")
+end
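A quick usage sketch of the two layers touched by this file, assuming a `GNNGraph` `g` with 3-dimensional node features `x` is available (the sizes below are arbitrary):

```julia
using Flux, GraphNeuralNetworks, Statistics

# ResGatedGraphConv, new in this commit: gated residual message passing
l = ResGatedGraphConv(3 => 8, relu)
y = l(g, x)                               # size (8, g.num_nodes)

# GINConv after the refactor: ϵ is positional, the aggregation is configurable
gin = GINConv(Dense(3, 8, relu), 0.01f0, aggr=mean)
z = gin(g, x)                             # size (8, g.num_nodes)
```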

src/utils.jl

Lines changed: 3 additions & 0 deletions

@@ -67,3 +67,6 @@ function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_nee
     end
     return data
 end
+
+
+ofeltype(x, y) = convert(float(eltype(x)), y)
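`ofeltype` converts a scalar to the floating-point element type of an array; it is what lets `update_node(::GINConv, ...)` mix the stored `ϵ` with `Float32` features without promoting them to `Float64`. A minimal illustration:

```julia
ofeltype(x, y) = convert(float(eltype(x)), y)   # as defined above

ofeltype(rand(Float32, 2, 2), 0.01)   # 0.01f0 (Float32)
ofeltype(rand(Int, 3), 1)             # 1.0 (float(Int) == Float64)
```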

test/examples/node_classification_cora.jl

Lines changed: 11 additions & 10 deletions

@@ -17,11 +17,11 @@ end

 # arguments for the `train` function
 Base.@kwdef mutable struct Args
-    η = 1f-3             # learning rate
-    epochs = 20          # number of epochs
+    η = 5f-3             # learning rate
+    epochs = 10          # number of epochs
     seed = 17            # set seed > 0 for reproducibility
     usecuda = false      # if true use cuda (if available)
-    nhidden = 128        # dimension of hidden features
+    nhidden = 64         # dimension of hidden features
 end

 function train(Layer; verbose=false, kws...)
@@ -49,7 +49,7 @@ function train(Layer; verbose=false, kws...)

     ## DEFINE MODEL
     model = GNNChain(Layer(nin, nhidden),
-                     Dropout(0.5),
+                     # Dropout(0.5),
                      Layer(nhidden, nhidden),
                      Dense(nhidden, nout)) |> device

@@ -70,8 +70,8 @@ function train(Layer; verbose=false, kws...)
             ŷ = model(g, X)
             logitcrossentropy(ŷ[:,train_ids], ytrain)
         end
-        verbose && report(epoch)
         Flux.Optimise.update!(opt, ps, gs)
+        verbose && report(epoch)
     end

     train_res = eval_loss_accuracy(X, y, train_ids, model, g)
@@ -84,15 +84,16 @@ for Layer in [
         (nin, nout) -> GraphConv(nin => nout, relu, aggr=mean),
         (nin, nout) -> SAGEConv(nin => nout, relu),
         (nin, nout) -> GATConv(nin => nout, relu),
-        (nin, nout) -> GATConv(nin => nout÷2, relu, heads=2),
-        (nin, nout) -> GINConv(Dense(nin, nout, relu)),
-        (nin, nout) -> ChebConv(nin => nout, 3),
+        (nin, nout) -> GINConv(Dense(nin, nout, relu), 0.01, aggr=mean),
+        (nin, nout) -> ChebConv(nin => nout, 2),
+        (nin, nout) -> ResGatedGraphConv(nin => nout, relu),
         # (nin, nout) -> NNConv(nin => nout), # needs edge features
         # (nin, nout) -> GatedGraphConv(nout, 2), # needs nin = nout
         # (nin, nout) -> EdgeConv(Dense(2nin, nout, relu)), # fits the training set but does not generalize well
     ]
-    train_res, test_res = train(Layer, verbose=true)
-    # @show Layer(2,2) train_res, test_res
+
+    # @show Layer(2,2)
+    train_res, test_res = train(Layer, verbose=false)
     @test train_res.acc > 95
     @test test_res.acc > 70
 end

test/layers/basic.jl

Lines changed: 19 additions & 17 deletions

@@ -2,32 +2,34 @@
 @testset "GNNChain" begin
     n, din, d, dout = 10, 3, 4, 2

-    g = GNNGraph(random_regular_graph(n, 4), graph_type=GRAPH_T)
+    g = GNNGraph(random_regular_graph(n, 4),
+                 graph_type=GRAPH_T,
+                 ndata=randn(Float32, din, n))

     gnn = GNNChain(GCNConv(din => d),
                    BatchNorm(d),
-                   x -> relu.(x),
-                   GraphConv(d => d, relu),
+                   x -> tanh.(x),
+                   GraphConv(d => d, tanh),
                    Dropout(0.5),
                    Dense(d, dout))
+
+    testmode!(gnn)

-    X = randn(Float32, din, n)
+    test_layer(gnn, g, rtol=1e-5)

-    y = gnn(g, X)
-
-    @test y isa Matrix{Float32}
-    @test size(y) == (dout, n)

-    @test length(params(gnn)) == 9
-
-    gs = gradient(x -> sum(gnn(g, x)), X)[1]
-    @test gs isa Matrix{Float32}
-    @test size(gs) == size(X)
+    @testset "Parallel" begin
+        AddResidual(l) = Parallel(+, identity, l)
+
+        gnn = GNNChain(ResGatedGraphConv(din => d, tanh),
+                       BatchNorm(d),
+                       AddResidual(ResGatedGraphConv(d => d, tanh)),
+                       BatchNorm(d),
+                       Dense(d, dout))

-    gs = gradient(() -> sum(gnn(g, X)), Flux.params(gnn))
-    for p in Flux.params(gnn)
-        @test eltype(gs[p]) == Float32
-        @test size(gs[p]) == size(p)
+        testmode!(gnn)
+
+        test_layer(gnn, g, rtol=1e-5)
     end
 end
 end

test/layers/conv.jl

Lines changed: 16 additions & 3 deletions

@@ -111,10 +111,10 @@

 @testset "GINConv" begin
     nn = Dense(in_channel, out_channel)
-    eps = 0.001f0
-    l = GINConv(nn, eps=eps)
+
+    l = GINConv(nn, 0.01f0, aggr=mean)
     for g in test_graphs
-        test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes), exclude_grad_fields=[:eps])
+        test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
     end

     @test !in(:eps, Flux.trainable(l))
@@ -149,4 +149,17 @@
         test_layer(l, g, rtol=1e-5, outsize=(out_channel, g.num_nodes))
     end
 end
+
+
+@testset "ResGatedGraphConv" begin
+    l = ResGatedGraphConv(in_channel => out_channel)
+    for g in test_graphs
+        test_layer(l, g, rtol=1e-5)
+    end
+
+    l = ResGatedGraphConv(in_channel => out_channel, tanh, bias=false)
+    for g in test_graphs
+        test_layer(l, g, rtol=1e-5)
+    end
+end
 end
