@@ -49,20 +49,6 @@ WithGraph(model, g::GNNGraph; traingraph=false) = WithGraph(model, g, traingraph
@functor WithGraph
Flux.trainable(l::WithGraph) = l.traingraph ? (; l.model, l.g) : (; l.model,)

- # Work around
- # https://github.com/FluxML/Flux.jl/issues/1733
- # Revisit after
- # https://github.com/FluxML/Flux.jl/pull/1742
- function Flux.destructure(m::WithGraph)
-     @assert m.traingraph == false # TODO
-     p, re = Flux.destructure(m.model)
-     function re_withgraph(x)
-         WithGraph(re(x), m.g, m.traingraph)
-     end
-
-     return p, re_withgraph
- end
-
(l::WithGraph)(g::GNNGraph, x...; kws...) = l.model(g, x...; kws...)
(l::WithGraph)(x...; kws...) = l.model(l.g, x...; kws...)
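A minimal usage sketch of the two forwarding methods kept above (not part of the diff; `rand_graph` and the layer sizes are just illustrative assumptions):

```julia
# Hedged sketch: wrap a model together with a fixed graph, then feed it
# node features only; the stored graph is supplied to the model internally.
using Flux, GraphNeuralNetworks

g = rand_graph(10, 30)                                    # toy graph: 10 nodes, 30 edges
model = GNNChain(GCNConv(3 => 8), x -> relu.(x), Dense(8, 2))
wg = WithGraph(model, g)                                  # traingraph=false by default

x = rand(Float32, 3, 10)
wg(x)       # same as model(g, x), using the stored graph
wg(g, x)    # an explicitly passed graph is used instead of the stored one
```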
@@ -85,74 +71,112 @@ and if names are given, `m[:name] == m[1]` etc.
# Examples
```juliarepl
- julia> m = GNNChain(GCNConv(2=>5), BatchNorm(5), x -> relu.(x), Dense(5, 4));
+ julia> using Flux, GraphNeuralNetworks
+
+ julia> m = GNNChain(GCNConv(2=>5),
+                     BatchNorm(5),
+                     x -> relu.(x),
+                     Dense(5, 4))
+ GNNChain(GCNConv(2 => 5), BatchNorm(5), #7, Dense(5 => 4))

julia> x = randn(Float32, 2, 3);

- julia> g = GNNGraph([1,1,2,3], [2,3,1,1]);
+ julia> g = rand_graph(3, 6)
+ GNNGraph:
+     num_nodes = 3
+     num_edges = 6

julia> m(g, x)
4×3 Matrix{Float32}:
-   0.157941    0.15443     0.193471
-   0.0819516   0.0503105   0.122523
-   0.225933    0.267901    0.241878
-  -0.0134364  -0.0120716  -0.0172505
+  -0.795592  -0.795592  -0.795592
+  -0.736409  -0.736409  -0.736409
+   0.994925   0.994925   0.994925
+   0.857549   0.857549   0.857549
+
+ julia> m2 = GNNChain(enc = m,
+                      dec = DotDecoder())
+ GNNChain(enc = GNNChain(GCNConv(2 => 5), BatchNorm(5), #7, Dense(5 => 4)), dec = DotDecoder())
+
+ julia> m2(g, x)
+ 1×6 Matrix{Float32}:
+  2.90053  2.90053  2.90053  2.90053  2.90053  2.90053
+
+ julia> m2[:enc](g, x) == m(g, x)
+ true
```
"""
- struct GNNChain{T} <: GNNLayer
+ struct GNNChain{T<:Union{Tuple, NamedTuple, AbstractVector}} <: GNNLayer
    layers::T
-
-     GNNChain(xs...) = new{typeof(xs)}(xs)
-
-     function GNNChain(; kw...)
-         :layers in Base.keys(kw) && throw(ArgumentError("a GNNChain cannot have a named layer called `layers`"))
-         isempty(kw) && return new{Tuple{}}(())
-         new{typeof(values(kw))}(values(kw))
-     end
end

- @forward GNNChain.layers Base.getindex, Base.length, Base.first, Base.last,
-          Base.iterate, Base.lastindex, Base.keys
+ @functor GNNChain

- Flux.functor(::Type{<:GNNChain}, c) = c.layers, ls -> GNNChain(ls...)
- Flux.functor(::Type{<:GNNChain}, c::Tuple) = c, ls -> GNNChain(ls...)
+ GNNChain(xs...) = GNNChain(xs)

- # input from graph
- applylayer(l, g::GNNGraph) = GNNGraph(g, ndata=l(node_features(g)))
- applylayer(l::GNNLayer, g::GNNGraph) = l(g)
+ function GNNChain(; kw...)
+     :layers in Base.keys(kw) && throw(ArgumentError("a GNNChain cannot have a named layer called `layers`"))
+     isempty(kw) && return GNNChain(())
+     GNNChain(values(kw))
+ end
+
+ @forward GNNChain.layers Base.getindex, Base.length, Base.first, Base.last,
+          Base.iterate, Base.lastindex, Base.keys, Base.firstindex
+
+ (c::GNNChain)(g::GNNGraph, x) = _applychain(c.layers, g, x)
+ (c::GNNChain)(g::GNNGraph) = _applychain(c.layers, g)
+
+ ## TODO see if this is faster for small chains
+ ## see https://github.com/FluxML/Flux.jl/pull/1809#discussion_r781691180
+ # @generated function _applychain(layers::Tuple{Vararg{<:Any,N}}, g::GNNGraph, x) where {N}
+ #     symbols = vcat(:x, [gensym() for _ in 1:N])
+ #     calls = [:($(symbols[i+1]) = _applylayer(layers[$i], $(symbols[i]))) for i in 1:N]
+ #     Expr(:block, calls...)
+ # end
+ # _applychain(layers::NamedTuple, g, x) = _applychain(Tuple(layers), x)
+
+ function _applychain(layers, g::GNNGraph, x)  # type-unstable path, helps compile times
+     for l in layers
+         x = _applylayer(l, g, x)
+     end
+     return x
+ end

- # explicit input
- applylayer(l, g::GNNGraph, x) = l(x)
- applylayer(l::GNNLayer, g::GNNGraph, x) = l(g, x)
+ function _applychain(layers, g::GNNGraph)  # type-unstable path, helps compile times
+     for l in layers
+         g = _applylayer(l, g)
+     end
+     return g
+ end

- # Handle Flux.Parallel
- applylayer(l::Parallel, g::GNNGraph) = GNNGraph(g, ndata=applylayer(l, g, node_features(g)))
- applylayer(l::Parallel, g::GNNGraph, x::AbstractArray) = mapreduce(f -> applylayer(f, g, x), l.connection, l.layers)
+ ## explicit input
+ _applylayer(l, g::GNNGraph, x) = l(x)
+ _applylayer(l::GNNLayer, g::GNNGraph, x) = l(g, x)

# input from graph
- applychain(::Tuple{}, g::GNNGraph) = g
- applychain(fs::Tuple, g::GNNGraph) = applychain(tail(fs), applylayer(first(fs), g))
-
- # explicit input
- applychain(::Tuple{}, g::GNNGraph, x) = x
- applychain(fs::Tuple, g::GNNGraph, x) = applychain(tail(fs), g, applylayer(first(fs), g, x))
+ _applylayer(l, g::GNNGraph) = GNNGraph(g, ndata=l(node_features(g)))
+ _applylayer(l::GNNLayer, g::GNNGraph) = l(g)

- (c::GNNChain)(g::GNNGraph, x) = applychain(Tuple(c.layers), g, x)
- (c::GNNChain)(g::GNNGraph) = applychain(Tuple(c.layers), g)
+ ## Handle Flux.Parallel
+ _applylayer(l::Parallel, g::GNNGraph) = GNNGraph(g, ndata=_applylayer(l, g, node_features(g)))

+ function _applylayer(l::Parallel, g::GNNGraph, x::AbstractArray)
+     closures = map(f -> (x -> _applylayer(f, g, x)), l.layers)
+     return Parallel(l.connection, closures)(x)
+ end

- Base.getindex(c::GNNChain, i::AbstractArray) = GNNChain(c.layers[i]...)
- Base.getindex(c::GNNChain{<:NamedTuple}, i::AbstractArray) =
-     GNNChain(; NamedTuple{Base.keys(c)[i]}(Tuple(c.layers)[i])...)
+ Base.getindex(c::GNNChain, i::AbstractArray) = GNNChain(c.layers[i])
+ Base.getindex(c::GNNChain{<:NamedTuple}, i::AbstractArray) =
+     GNNChain(NamedTuple{keys(c)[i]}(Tuple(c.layers)[i]))

function Base.show(io::IO, c::GNNChain)
    print(io, "GNNChain(")
    _show_layers(io, c.layers)
    print(io, ")")
end
+
_show_layers(io, layers::Tuple) = join(io, layers, ", ")
_show_layers(io, layers::NamedTuple) = join(io, ["$k = $v" for (k, v) in pairs(layers)], ", ")
-
+ _show_layers(io, layers::AbstractVector) = (print(io, "["); join(io, layers, ", "); print(io, "]"))

"""
    DotDecoder()
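Not part of the diff: a small hedged sketch of the indexing behavior promised by `m[:name] == m[1]` and the `Base.getindex` methods above (layer names and sizes are made up):

```julia
# Hedged sketch: positional, named, and ranged indexing into a named GNNChain.
using Flux, GraphNeuralNetworks

m = GNNChain(enc = GCNConv(2 => 4), act = x -> relu.(x), dec = Dense(4, 1))

m[:enc] === m[1]   # true: symbol access and positional access return the same layer
m[1:2]             # a new GNNChain holding the first two layers, names preserved
```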
@@ -181,5 +205,5 @@ struct DotDecoder <: GNNLayer end

function (::DotDecoder)(g, x)
    check_num_nodes(g, x)
-     apply_edges(xi_dot_xj, g, xi=x, xj=x)
+     return apply_edges(xi_dot_xj, g, xi=x, xj=x)
end
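To make the returned shape concrete, here is a tiny hand-checked sketch (not from the source; the graph and features are invented for illustration):

```julia
# Hedged sketch: DotDecoder returns one dot product per edge, as a 1 × num_edges matrix.
using GraphNeuralNetworks

g = GNNGraph([1, 2], [2, 1])   # two directed edges: 1→2 and 2→1
x = Float32[1 3;
            2 4]               # node features, one column per node

DotDecoder()(g, x)             # ≈ [11.0 11.0], since dot([1,2], [3,4]) = 1*3 + 2*4 = 11
```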