@@ -13,7 +13,7 @@ abstract type GNNLayer end
13
13
14
14
15
15
"""
16
- WithGraph(model, g::GNNGraph; traingraph=false)
16
+ WithGraph(model, g::GNNGraph; traingraph=false)
17
17
18
18
A type wrapping the `model` and tying it to the graph `g`.
19
19
In the forward pass, can only take feature arrays as inputs,
@@ -38,17 +38,31 @@ x2 = rand(Float32, 2, 4)
38
38
@assert wg(g2, x2) == model(g2, x2)
39
39
```
40
40
"""
41
- struct WithGraph{M}
42
- model:: M
43
- g:: GNNGraph
44
- traingraph:: Bool
41
+ struct WithGraph{M, G <: GNNGraph }
42
+ model:: M
43
+ g:: G
44
+ traingraph:: Bool
45
45
end
46
46
47
47
WithGraph (model, g:: GNNGraph ; traingraph= false ) = WithGraph (model, g, traingraph)
48
48
49
49
@functor WithGraph
50
50
Flux. trainable (l:: WithGraph ) = l. traingraph ? (l. model, l. g) : (l. model,)
51
51
52
+ # Work around
53
+ # https://github.com/FluxML/Flux.jl/issues/1733
54
+ # Revisit after
55
+ # https://github.com/FluxML/Flux.jl/pull/1742
56
+ function Flux. destructure (m:: WithGraph )
57
+ @assert m. traingraph == false # TODO
58
+ p, re = Flux. destructure (m. model)
59
+ function re_withgraph (x)
60
+ WithGraph (re (x), m. g, m. traingraph)
61
+ end
62
+
63
+ return p, re_withgraph
64
+ end
65
+
52
66
(l:: WithGraph )(g:: GNNGraph , x... ; kws... ) = l. model (g, x... ; kws... )
53
67
(l:: WithGraph )(x... ; kws... ) = l. model (l. g, x... ; kws... )
54
68
@@ -86,15 +100,15 @@ julia> m(g, x)
86
100
```
87
101
"""
88
102
struct GNNChain{T} <: GNNLayer
89
- layers:: T
90
-
91
- GNNChain (xs... ) = new {typeof(xs)} (xs)
92
-
93
- function GNNChain (; kw... )
94
- :layers in Base. keys (kw) && throw (ArgumentError (" a GNNChain cannot have a named layer called `layers`" ))
95
- isempty (kw) && return new {Tuple{}} (())
96
- new {typeof(values(kw))} (values (kw))
97
- end
103
+ layers:: T
104
+
105
+ GNNChain (xs... ) = new {typeof(xs)} (xs)
106
+
107
+ function GNNChain (; kw... )
108
+ :layers in Base. keys (kw) && throw (ArgumentError (" a GNNChain cannot have a named layer called `layers`" ))
109
+ isempty (kw) && return new {Tuple{}} (())
110
+ new {typeof(values(kw))} (values (kw))
111
+ end
98
112
end
99
113
100
114
@forward GNNChain. layers Base. getindex, Base. length, Base. first, Base. last,
0 commit comments