Skip to content

Commit bba3834

Browse files
committed
fix GNNChain restructure bug
1 parent 401ae1a commit bba3834

File tree

2 files changed

+7
-0
lines changed

2 files changed

+7
-0
lines changed

src/layers/basic.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ end
115115
Base.iterate, Base.lastindex, Base.keys
116116

117117
Flux.functor(::Type{<:GNNChain}, c) = c.layers, ls -> GNNChain(ls...)
118+
Flux.functor(::Type{<:GNNChain}, c::Tuple) = c, ls -> GNNChain(ls...)
118119

119120
# input from graph
120121
applylayer(l, g::GNNGraph) = GNNGraph(g, ndata=l(node_features(g)))

test/layers/basic.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,5 +64,11 @@
6464
wg = WithGraph(model, g, traingraph=true)
6565
@test length(Flux.params(wg)) == length(Flux.params(model)) + length(Flux.params(g))
6666
end
67+
68+
@testset "Flux restructure" begin
69+
chain = GNNChain(GraphConv(2=>2))
70+
params, restructure = Flux.destructure(chain)
71+
restructure(params)
72+
end
6773
end
6874

0 commit comments

Comments
 (0)