Skip to content

Commit de14ded

Browse files
Merge pull request #50 from CarloLucibello/cl/fix
better handling of chain with only graph input
2 parents 867cc10 + 6b9c3f5 commit de14ded

File tree

3 files changed

+25
-3
lines changed

3 files changed

+25
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GraphNeuralNetworks"
22
uuid = "cffab07f-9bc2-4db1-8861-388f63bf7694"
33
authors = ["Carlo Lucibello and contributors"]
4-
version = "0.2.0"
4+
version = "0.2.1"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/layers/basic.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,19 +60,29 @@ end
6060

6161
Flux.functor(::Type{<:GNNChain}, c) = c.layers, ls -> GNNChain(ls...)
6262

63+
# input from graph
64+
applylayer(l, g::GNNGraph) = GNNGraph(g, ndata=l(node_features(g)))
65+
applylayer(l::GNNLayer, g::GNNGraph) = l(g)
66+
67+
# explicit input
6368
applylayer(l, g::GNNGraph, x) = l(x)
6469
applylayer(l::GNNLayer, g::GNNGraph, x) = l(g, x)
6570

6671
# Handle Flux.Parallel
72+
applylayer(l::Parallel, g::GNNGraph) = GNNGraph(g, ndata=applylayer(l, g, node_features(g)))
6773
applylayer(l::Parallel, g::GNNGraph, x::AbstractArray) = mapreduce(f -> applylayer(f, g, x), l.connection, l.layers)
68-
applylayer(l::Parallel, g::GNNGraph, xs::Vararg{<:AbstractArray}) = mapreduce((f, x) -> applylayer(f, g, x), l.connection, l.layers, xs)
69-
applylayer(l::Parallel, g::GNNGraph, xs::Tuple) = applylayer(l, g, xs...)
7074

75+
# input from graph
76+
applychain(::Tuple{}, g::GNNGraph) = g
77+
applychain(fs::Tuple, g::GNNGraph) = applychain(tail(fs), applylayer(first(fs), g))
7178

79+
# explicit input
7280
applychain(::Tuple{}, g::GNNGraph, x) = x
7381
applychain(fs::Tuple, g::GNNGraph, x) = applychain(tail(fs), g, applylayer(first(fs), g, x))
7482

7583
(c::GNNChain)(g::GNNGraph, x) = applychain(Tuple(c.layers), g, x)
84+
(c::GNNChain)(g::GNNGraph) = applychain(Tuple(c.layers), g)
85+
7686

7787
Base.getindex(c::GNNChain, i::AbstractArray) = GNNChain(c.layers[i]...)
7888
Base.getindex(c::GNNChain{<:NamedTuple}, i::AbstractArray) =

test/layers/basic.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,18 @@
3131

3232
test_layer(gnn, g, rtol=1e-5, exclude_grad_fields=[, :σ²])
3333
end
34+
35+
@testset "Only graph input" begin
36+
nin, nout = 2, 4
37+
ndata = rand(nin, 3)
38+
edata = rand(nin, 3)
39+
g = GNNGraph([1,1,2], [2, 3, 3], ndata=ndata, edata=edata)
40+
m = NNConv(nin => nout, Dense(2, nin*nout, tanh))
41+
chain = GNNChain(m)
42+
y = m(g, g.ndata.x, g.edata.e)
43+
@test m(g).ndata.x == y
44+
@test chain(g).ndata.x == y
45+
end
3446
end
3547
end
3648

0 commit comments

Comments
 (0)