Skip to content

Commit 651c216

Browse files
Update GNNChain (#202)
* update GNNChain * cleanup * cleanup * update compat bounds * improve docstring * add tests * more tests
1 parent 987fd03 commit 651c216

File tree

3 files changed

+105
-58
lines changed

3 files changed

+105
-58
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GraphNeuralNetworks"
22
uuid = "cffab07f-9bc2-4db1-8861-388f63bf7694"
33
authors = ["Carlo Lucibello and contributors"]
4-
version = "0.4.4"
4+
version = "0.4.5"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/layers/basic.jl

Lines changed: 79 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -49,20 +49,6 @@ WithGraph(model, g::GNNGraph; traingraph=false) = WithGraph(model, g, traingraph
4949
@functor WithGraph
5050
Flux.trainable(l::WithGraph) = l.traingraph ? (; l.model, l.g) : (; l.model,)
5151

52-
# Work around
53-
# https://github.com/FluxML/Flux.jl/issues/1733
54-
# Revisit after
55-
# https://github.com/FluxML/Flux.jl/pull/1742
56-
function Flux.destructure(m::WithGraph)
57-
@assert m.traingraph == false # TODO
58-
p, re = Flux.destructure(m.model)
59-
function re_withgraph(x)
60-
WithGraph(re(x), m.g, m.traingraph)
61-
end
62-
63-
return p, re_withgraph
64-
end
65-
6652
(l::WithGraph)(g::GNNGraph, x...; kws...) = l.model(g, x...; kws...)
6753
(l::WithGraph)(x...; kws...) = l.model(l.g, x...; kws...)
6854

@@ -85,74 +71,112 @@ and if names are given, `m[:name] == m[1]` etc.
8571
# Examples
8672
8773
```juliarepl
88-
julia> m = GNNChain(GCNConv(2=>5), BatchNorm(5), x -> relu.(x), Dense(5, 4));
74+
julia> using Flux, GraphNeuralNetworks
75+
76+
julia> m = GNNChain(GCNConv(2=>5),
77+
BatchNorm(5),
78+
x -> relu.(x),
79+
Dense(5, 4))
80+
GNNChain(GCNConv(2 => 5), BatchNorm(5), #7, Dense(5 => 4))
8981
9082
julia> x = randn(Float32, 2, 3);
9183
92-
julia> g = GNNGraph([1,1,2,3], [2,3,1,1]);
84+
julia> g = rand_graph(3, 6)
85+
GNNGraph:
86+
num_nodes = 3
87+
num_edges = 6
9388
9489
julia> m(g, x)
9590
4×3 Matrix{Float32}:
96-
0.157941 0.15443 0.193471
97-
0.0819516 0.0503105 0.122523
98-
0.225933 0.267901 0.241878
99-
-0.0134364 -0.0120716 -0.0172505
91+
-0.795592 -0.795592 -0.795592
92+
-0.736409 -0.736409 -0.736409
93+
0.994925 0.994925 0.994925
94+
0.857549 0.857549 0.857549
95+
96+
julia> m2 = GNNChain(enc = m,
97+
dec = DotDecoder())
98+
GNNChain(enc = GNNChain(GCNConv(2 => 5), BatchNorm(5), #7, Dense(5 => 4)), dec = DotDecoder())
99+
100+
julia> m2(g, x)
101+
1×6 Matrix{Float32}:
102+
2.90053 2.90053 2.90053 2.90053 2.90053 2.90053
103+
104+
julia> m2[:enc](g, x) == m(g, x)
105+
true
100106
```
101107
"""
102-
struct GNNChain{T} <: GNNLayer
108+
struct GNNChain{T<:Union{Tuple, NamedTuple, AbstractVector}} <: GNNLayer
103109
layers::T
104-
105-
GNNChain(xs...) = new{typeof(xs)}(xs)
106-
107-
function GNNChain(; kw...)
108-
:layers in Base.keys(kw) && throw(ArgumentError("a GNNChain cannot have a named layer called `layers`"))
109-
isempty(kw) && return new{Tuple{}}(())
110-
new{typeof(values(kw))}(values(kw))
111-
end
112110
end
113111

114-
@forward GNNChain.layers Base.getindex, Base.length, Base.first, Base.last,
115-
Base.iterate, Base.lastindex, Base.keys
112+
@functor GNNChain
116113

117-
Flux.functor(::Type{<:GNNChain}, c) = c.layers, ls -> GNNChain(ls...)
118-
Flux.functor(::Type{<:GNNChain}, c::Tuple) = c, ls -> GNNChain(ls...)
114+
GNNChain(xs...) = GNNChain(xs)
119115

120-
# input from graph
121-
applylayer(l, g::GNNGraph) = GNNGraph(g, ndata=l(node_features(g)))
122-
applylayer(l::GNNLayer, g::GNNGraph) = l(g)
116+
function GNNChain(; kw...)
117+
:layers in Base.keys(kw) && throw(ArgumentError("a GNNChain cannot have a named layer called `layers`"))
118+
isempty(kw) && return GNNChain(())
119+
GNNChain(values(kw))
120+
end
121+
122+
@forward GNNChain.layers Base.getindex, Base.length, Base.first, Base.last,
123+
Base.iterate, Base.lastindex, Base.keys, Base.firstindex
124+
125+
(c::GNNChain)(g::GNNGraph, x) = _applychain(c.layers, g, x)
126+
(c::GNNChain)(g::GNNGraph) = _applychain(c.layers, g)
127+
128+
## TODO see if this is faster for small chains
129+
## see https://github.com/FluxML/Flux.jl/pull/1809#discussion_r781691180
130+
# @generated function _applychain(layers::Tuple{Vararg{<:Any,N}}, g::GNNGraph, x) where {N}
131+
# symbols = vcat(:x, [gensym() for _ in 1:N])
132+
# calls = [:($(symbols[i+1]) = _applylayer(layers[$i], $(symbols[i]))) for i in 1:N]
133+
# Expr(:block, calls...)
134+
# end
135+
# _applychain(layers::NamedTuple, g, x) = _applychain(Tuple(layers), x)
136+
137+
function _applychain(layers, g::GNNGraph, x) # type-unstable path, helps compile times
138+
for l in layers
139+
x = _applylayer(l, g, x)
140+
end
141+
return x
142+
end
123143

124-
# explicit input
125-
applylayer(l, g::GNNGraph, x) = l(x)
126-
applylayer(l::GNNLayer, g::GNNGraph, x) = l(g, x)
144+
function _applychain(layers, g::GNNGraph) # type-unstable path, helps compile times
145+
for l in layers
146+
g = _applylayer(l, g)
147+
end
148+
return g
149+
end
127150

128-
# Handle Flux.Parallel
129-
applylayer(l::Parallel, g::GNNGraph) = GNNGraph(g, ndata=applylayer(l, g, node_features(g)))
130-
applylayer(l::Parallel, g::GNNGraph, x::AbstractArray) = mapreduce(f -> applylayer(f, g, x), l.connection, l.layers)
151+
# # explicit input
152+
_applylayer(l, g::GNNGraph, x) = l(x)
153+
_applylayer(l::GNNLayer, g::GNNGraph, x) = l(g, x)
131154

132155
# input from graph
133-
applychain(::Tuple{}, g::GNNGraph) = g
134-
applychain(fs::Tuple, g::GNNGraph) = applychain(tail(fs), applylayer(first(fs), g))
135-
136-
# explicit input
137-
applychain(::Tuple{}, g::GNNGraph, x) = x
138-
applychain(fs::Tuple, g::GNNGraph, x) = applychain(tail(fs), g, applylayer(first(fs), g, x))
156+
_applylayer(l, g::GNNGraph) = GNNGraph(g, ndata=l(node_features(g)))
157+
_applylayer(l::GNNLayer, g::GNNGraph) = l(g)
139158

140-
(c::GNNChain)(g::GNNGraph, x) = applychain(Tuple(c.layers), g, x)
141-
(c::GNNChain)(g::GNNGraph) = applychain(Tuple(c.layers), g)
159+
# # Handle Flux.Parallel
160+
_applylayer(l::Parallel, g::GNNGraph) = GNNGraph(g, ndata=_applylayer(l, g, node_features(g)))
142161

162+
function _applylayer(l::Parallel, g::GNNGraph, x::AbstractArray)
163+
closures = map(f -> (x -> _applylayer(f, g, x)), l.layers)
164+
return Parallel(l.connection, closures)(x)
165+
end
143166

144-
Base.getindex(c::GNNChain, i::AbstractArray) = GNNChain(c.layers[i]...)
145-
Base.getindex(c::GNNChain{<:NamedTuple}, i::AbstractArray) =
146-
GNNChain(; NamedTuple{Base.keys(c)[i]}(Tuple(c.layers)[i])...)
167+
Base.getindex(c::GNNChain, i::AbstractArray) = GNNChain(c.layers[i])
168+
Base.getindex(c::GNNChain{<:NamedTuple}, i::AbstractArray) =
169+
GNNChain(NamedTuple{keys(c)[i]}(Tuple(c.layers)[i]))
147170

148171
function Base.show(io::IO, c::GNNChain)
149172
print(io, "GNNChain(")
150173
_show_layers(io, c.layers)
151174
print(io, ")")
152175
end
176+
153177
_show_layers(io, layers::Tuple) = join(io, layers, ", ")
154178
_show_layers(io, layers::NamedTuple) = join(io, ["$k = $v" for (k, v) in pairs(layers)], ", ")
155-
179+
_show_layers(io, layers::AbstractVector) = (print(io, "["); join(io, layers, ", "); print(io, "]"))
156180

157181
"""
158182
DotDecoder()
@@ -181,5 +205,5 @@ struct DotDecoder <: GNNLayer end
181205

182206
function (::DotDecoder)(g, x)
183207
check_num_nodes(g, x)
184-
apply_edges(xi_dot_xj, g, xi=x, xj=x)
208+
return apply_edges(xi_dot_xj, g, xi=x, xj=x)
185209
end

test/layers/basic.jl

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
@testset "basic" begin
22
@testset "GNNChain" begin
33
n, din, d, dout = 10, 3, 4, 2
4+
deg = 4
45

5-
g = GNNGraph(random_regular_graph(n, 4),
6+
g = GNNGraph(random_regular_graph(n, deg),
67
graph_type=GRAPH_T,
78
ndata= randn(Float32, din, n))
8-
9+
x = g.ndata.x
10+
911
gnn = GNNChain(GCNConv(din => d),
1012
BatchNorm(d),
1113
x -> tanh.(x),
@@ -17,6 +19,27 @@
1719

1820
test_layer(gnn, g, rtol=1e-5, exclude_grad_fields=[:μ, :σ²])
1921

22+
@testset "constructor with names" begin
23+
m = GNNChain(GCNConv(din=>d),
24+
BatchNorm(d),
25+
x -> relu.(x),
26+
Dense(d, dout))
27+
28+
m2 = GNNChain(enc = m,
29+
dec = DotDecoder())
30+
31+
@test m2[:enc] === m
32+
@test m2(g, x) == m2[:dec](g, m2[:enc](g, x))
33+
end
34+
35+
@testset "constructor with vector" begin
36+
m = GNNChain(GCNConv(din=>d),
37+
BatchNorm(d),
38+
x -> relu.(x),
39+
Dense(d, dout))
40+
m2 = GNNChain([m.layers...])
41+
@test m2(g, x) == m(g, x)
42+
end
2043

2144
@testset "Parallel" begin
2245
AddResidual(l) = Parallel(+, identity, l)

0 commit comments

Comments
 (0)