Skip to content

Commit b7a6516

Browse files
better handling of chain with graph input
1 parent 867cc10 commit b7a6516

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

src/layers/basic.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ end
6060

6161
Flux.functor(::Type{<:GNNChain}, c) = c.layers, ls -> GNNChain(ls...)
6262

63+
# input from graph
64+
applylayer(l, g::GNNGraph) = GNNGraph(g, ndata=l(node_features(g)))
65+
applylayer(l::GNNLayer, g::GNNGraph) = l(g)
66+
67+
# explicit input
6368
applylayer(l, g::GNNGraph, x) = l(x)
6469
applylayer(l::GNNLayer, g::GNNGraph, x) = l(g, x)
6570

@@ -68,11 +73,17 @@ applylayer(l::Parallel, g::GNNGraph, x::AbstractArray) = mapreduce(f -> applylay
6873
applylayer(l::Parallel, g::GNNGraph, xs::Vararg{<:AbstractArray}) = mapreduce((f, x) -> applylayer(f, g, x), l.connection, l.layers, xs)
6974
applylayer(l::Parallel, g::GNNGraph, xs::Tuple) = applylayer(l, g, xs...)
7075

76+
# input from graph
77+
applychain(::Tuple{}, g::GNNGraph) = g
78+
applychain(fs::Tuple, g::GNNGraph) = applychain(tail(fs), applylayer(first(fs), g))
7179

80+
# explicit input
7281
applychain(::Tuple{}, g::GNNGraph, x) = x
7382
applychain(fs::Tuple, g::GNNGraph, x) = applychain(tail(fs), g, applylayer(first(fs), g, x))
7483

7584
(c::GNNChain)(g::GNNGraph, x) = applychain(Tuple(c.layers), g, x)
85+
(c::GNNChain)(g::GNNGraph) = applychain(Tuple(c.layers), g)
86+
7687

7788
Base.getindex(c::GNNChain, i::AbstractArray) = GNNChain(c.layers[i]...)
7889
Base.getindex(c::GNNChain{<:NamedTuple}, i::AbstractArray) =

test/layers/basic.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,18 @@
3131

3232
test_layer(gnn, g, rtol=1e-5, exclude_grad_fields=[, :σ²])
3333
end
34+
35+
@testset "Only graph input" begin
36+
nin, nout = 2, 4
37+
ndata = rand(nin, 3)
38+
edata = rand(nin, 3)
39+
g = GNNGraph([1,1,2], [2, 3, 3], ndata=ndata, edata=edata)
40+
m = NNConv(nin => nout, Dense(2, nin*nout, tanh))
41+
chain = GNNChain(m)
42+
y = m(g, g.ndata.x, g.edata.e)
43+
@test m(g).ndata.x == y
44+
@test chain(g).ndata.x == y
45+
end
3446
end
3547
end
3648

0 commit comments

Comments
 (0)