Skip to content

Commit 6239b48

Browse files
implement GNNChain
1 parent d57fb34 commit 6239b48

File tree

12 files changed

+120
-35
lines changed

12 files changed

+120
-35
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ DataStructures = "0.18"
2525
Flux = "0.12"
2626
KrylovKit = "0.5"
2727
LightGraphs = "1.3"
28+
MacroTools = "0.5"
2829
NNlib = "0.7"
2930
NNlibCUDA = "0.1"
3031
julia = "1.6"

docs/make.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@ makedocs(;
99
pages = ["Home" => "index.md",
1010
"GNNGraph" => "gnngraph.md",
1111
"Message Passing" => "messagepassing.md",
12-
"Building Models" => "models.md",
12+
"Model Building" => "models.md",
1313
"API Reference" =>
1414
[
1515
"GNNGraph" => "api/gnngraph.md",
16+
"Basic Layers" => "api/basic.md",
1617
"Convolutional Layers" => "api/conv.md",
1718
"Pooling Layers" => "api/pool.md",
1819
],

docs/src/api/basic.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Basic Layers
2+
3+
```@index
4+
Order = [:type, :function]
5+
Pages = ["api/basics.md"]
6+
```
7+
8+
```@autodocs
9+
Modules = [GraphNeuralNetworks]
10+
Pages = ["layers/basic.jl"]
11+
Private = false
12+
```

docs/src/models.md

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,52 @@
11
# Models
2+
3+
## Explicit modeling
4+
5+
```julia
6+
using Flux, GraphNeuralNetworks
7+
using Flux: @functor
8+
9+
struct GNN
10+
conv1
11+
bn
12+
conv2
13+
dropout
14+
dense
15+
end
16+
17+
@functor GNN
18+
19+
function GNN(din, d, dout)
20+
GNN(GCNConv(din => d),
21+
BatchNorm(d),
22+
GraphConv(d => d, relu),
23+
Dropout(0.5),
24+
Dense(d, dout))
25+
end
26+
27+
function (model::GNN)(g::GNNGraph, x)
28+
x = model.conv1(g, x)
29+
x = relu.(model.bn(x))
30+
x = model.conv2(g, x)
31+
x = model.dropout(x)
32+
x = model.dense(x)
33+
return x
34+
end
35+
36+
din, d, dout = 3, 4, 2
37+
g = GNNGraph(random_regular_graph(10, 4), graph_type=GRAPH_T)
38+
X = randn(Float32, din, 10)
39+
model = GNN(din, d, dout)
40+
y = model(g, X)
41+
```
42+
43+
## Compact modeling with GNNChains
44+
45+
```julia
46+
model = GNNChain(GCNConv(din => d),
47+
BatchNorm(d),
48+
x -> relu.(x),
49+
GraphConv(d => d, relu),
50+
Dropout(0.5),
51+
Dense(d, dout))
52+
```

examples/cora.jl

Lines changed: 10 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,28 +9,6 @@ using Statistics, Random
99
using CUDA
1010
CUDA.allowscalar(false)
1111

12-
struct GNN
13-
conv1
14-
conv2
15-
dense
16-
end
17-
18-
@functor GNN
19-
20-
function GNN(; nin, nhidden, nout)
21-
GNN(GCNConv(nin => nhidden, relu),
22-
GCNConv(nhidden => nhidden, relu),
23-
Dense(nhidden, nout))
24-
end
25-
26-
function (net::GNN)(g, x)
27-
x = net.conv1(g, x)
28-
x = dropout(x, 0.5)
29-
x = net.conv2(g, x)
30-
x = net.dense(x)
31-
return x
32-
end
33-
3412
function eval_loss_accuracy(X, y, ids, model, g)
3513
= model(g, X)
3614
l = logitcrossentropy(ŷ[:,ids], y[:,ids])
@@ -63,6 +41,7 @@ function train(; kws...)
6341
@info "Training on CPU"
6442
end
6543

44+
# LOAD DATA
6645
data = Cora.dataset()
6746
g = GNNGraph(data.adjacency_list) |> device
6847
X = data.node_features |> device
@@ -72,14 +51,20 @@ function train(; kws...)
7251
test_ids = data.test_indices |> device
7352
ytrain = y[:,train_ids]
7453

75-
model = GNN(nin=size(X,1),
76-
nhidden=args.nhidden,
77-
nout=data.num_classes) |> device
54+
nin, nhidden, nout = size(X,1), args.nhidden, data.num_classes
55+
56+
## DEFINE MODEL
57+
model = GNNGraph(GCNConv(nin => nhidden, relu),
58+
Dropout(0.5)
59+
GCNConv(nhidden => nhidden, relu),
60+
Dense(nhidden, nout)) |> device
61+
7862
ps = Flux.params(model)
7963
opt = ADAM(args.η)
8064

8165
@info "NUM NODES: $(g.num_nodes) NUM EDGES: $(g.num_edges)"
8266

67+
## LOGGING FUNCTION
8368
function report(epoch)
8469
train = eval_loss_accuracy(X, y, train_ids, model, g)
8570
test = eval_loss_accuracy(X, y, test_ids, model, g)

src/GraphNeuralNetworks.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
module GraphNeuralNetworks
22

3+
using Core: apply_type
34
using NNlib: similar
45
using LinearAlgebra: similar, fill!
56
using Statistics: mean
67
using LinearAlgebra
78
using SparseArrays
89
import KrylovKit
10+
using Base: tail
911
using CUDA
1012
using Flux
1113
using Flux: glorot_uniform, leakyrelu, GRUCell, @functor
@@ -27,7 +29,8 @@ export
2729
# from LightGraphs
2830
adjacency_matrix,
2931

30-
# layers/msgpass
32+
# msgpass
33+
# update, update_edge, update_global, message, propagate,
3134

3235
# layers/basic
3336
GNNLayer,
@@ -54,11 +57,9 @@ export
5457
include("gnngraph.jl")
5558
include("graph_conversions.jl")
5659
include("utils.jl")
57-
58-
include("layers/msgpass.jl")
60+
include("msgpass.jl")
5961
include("layers/basic.jl")
6062
include("layers/conv.jl")
6163
include("layers/pool.jl")
6264

63-
6465
end

src/layers/basic.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,15 @@ end
5252
@forward GNNChain.layers Base.getindex, Base.length, Base.first, Base.last,
5353
Base.iterate, Base.lastindex, Base.keys
5454

55-
functor(::Type{<:GNNChain}, c) = c.layers, ls -> GNNChain(ls...)
55+
Flux.functor(::Type{<:GNNChain}, c) = c.layers, ls -> GNNChain(ls...)
5656

57-
applychain(::Tuple{}, x) = x
58-
applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
57+
applylayer(l, g::GNNGraph, x) = l(x)
58+
applylayer(l::GNNLayer, g::GNNGraph, x) = l(g, x)
5959

60-
(c::GNNChain)(x) = applychain(Tuple(c.layers), x)
60+
applychain(::Tuple{}, g::GNNGraph, x) = x
61+
applychain(fs::Tuple, g::GNNGraph, x) = applychain(tail(fs), g, applylayer(first(fs), g, x))
62+
63+
(c::GNNChain)(g::GNNGraph, x) = applychain(Tuple(c.layers), g, x)
6164

6265
Base.getindex(c::GNNChain, i::AbstractArray) = GNNChain(c.layers[i]...)
6366
Base.getindex(c::GNNChain{<:NamedTuple}, i::AbstractArray) =
File renamed without changes.
File renamed without changes.

test/layers/basic.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,34 @@
11
@testset "basic" begin
2+
@testset "GNNChain" begin
3+
n, din, d, dout = 10, 3, 4, 2
4+
5+
g = GNNGraph(random_regular_graph(n, 4), graph_type=GRAPH_T)
6+
7+
gnn = GNNChain(GCNConv(din => d),
8+
BatchNorm(d),
9+
x -> relu.(x),
10+
GraphConv(d => d, relu),
11+
Dropout(0.5),
12+
Dense(d, dout))
13+
14+
X = randn(Float32, din, n)
215

16+
y = gnn(g, X)
17+
18+
@test y isa Matrix{Float32}
19+
@test size(y) == (dout, n)
20+
21+
@test length(params(gnn)) == 9
22+
23+
gs = gradient(x -> sum(gnn(g, x)), X)[1]
24+
@test gs isa Matrix{Float32}
25+
@test size(gs) == size(X)
26+
27+
gs = gradient(() -> sum(gnn(g, X)), Flux.params(gnn))
28+
for p in Flux.params(gnn)
29+
@test eltype(gs[p]) == Float32
30+
@test size(gs[p]) == size(p)
31+
end
32+
end
333
end
34+

0 commit comments

Comments
 (0)