We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 401ae1a commit bba3834Copy full SHA for bba3834
src/layers/basic.jl
@@ -115,6 +115,7 @@ end
115
Base.iterate, Base.lastindex, Base.keys
116
117
Flux.functor(::Type{<:GNNChain}, c) = c.layers, ls -> GNNChain(ls...)
118
+Flux.functor(::Type{<:GNNChain}, c::Tuple) = c, ls -> GNNChain(ls...)
119
120
# input from graph
121
applylayer(l, g::GNNGraph) = GNNGraph(g, ndata=l(node_features(g)))
test/layers/basic.jl
@@ -64,5 +64,11 @@
64
wg = WithGraph(model, g, traingraph=true)
65
@test length(Flux.params(wg)) == length(Flux.params(model)) + length(Flux.params(g))
66
end
67
+
68
+ @testset "Flux restructure" begin
69
+ chain = GNNChain(GraphConv(2=>2))
70
+ params, restructure = Flux.destructure(chain)
71
+ restructure(params)
72
+ end
73
74
0 commit comments