@@ -93,3 +93,184 @@ LuxCore.outputsize(l::A3TGCN) = (l.out_dims,)
93
93
function Base. show (io:: IO , l:: A3TGCN )
94
94
print (io, " A3TGCN($(l. in_dims) => $(l. out_dims) )" )
95
95
end
96
+
97
+ @concrete struct GConvGRUCell <: GNNContainerLayer{(:conv_x_r, :conv_h_r, :conv_x_z, :conv_h_z, :conv_x_h, :conv_h_h)}
98
+ in_dims:: Int
99
+ out_dims:: Int
100
+ k:: Int
101
+ conv_x_r
102
+ conv_h_r
103
+ conv_x_z
104
+ conv_h_z
105
+ conv_x_h
106
+ conv_h_h
107
+ init_state:: Function
108
+ end
109
+
110
+ function GConvGRUCell (ch:: Pair{Int, Int} , k:: Int ; use_bias = true , init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32)
111
+ in_dims, out_dims = ch
112
+ # reset gate
113
+ conv_x_r = ChebConv (in_dims => out_dims, k; use_bias, init_weight, init_bias)
114
+ conv_h_r = ChebConv (out_dims => out_dims, k; use_bias, init_weight, init_bias)
115
+ # update gate
116
+ conv_x_z = ChebConv (in_dims => out_dims, k; use_bias, init_weight, init_bias)
117
+ conv_h_z = ChebConv (out_dims => out_dims, k; use_bias, init_weight, init_bias)
118
+ # hidden state
119
+ conv_x_h = ChebConv (in_dims => out_dims, k; use_bias, init_weight, init_bias)
120
+ conv_h_h = ChebConv (out_dims => out_dims, k; use_bias, init_weight, init_bias)
121
+ return GConvGRUCell (in_dims, out_dims, k, conv_x_r, conv_h_r, conv_x_z, conv_h_z, conv_x_h, conv_h_h, init_state)
122
+ end
123
+
124
+ function (l:: GConvGRUCell )(g, (x, h), ps, st)
125
+ if h === nothing
126
+ h = l. init_state (l. out_dims, g. num_nodes)
127
+ end
128
+ xr, st_conv_xr = l. conv_x_r (g, x, ps. conv_x_r, st. conv_x_r)
129
+ hr, st_conv_hr = l. conv_h_r (g, h, ps. conv_h_r, st. conv_h_r)
130
+ r = xr .+ hr
131
+ r = NNlib. sigmoid_fast (r)
132
+ xz, st_conv_x_z = l. conv_x_z (g, x, ps. conv_x_z, st. conv_x_z)
133
+ hz, st_conv_h_z = l. conv_h_z (g, h, ps. conv_h_z, st. conv_h_z)
134
+ z = xz .+ hz
135
+ z = NNlib. sigmoid_fast (z)
136
+ xh, st_conv_x_h = l. conv_x_h (g, x, ps. conv_x_h, st. conv_x_h)
137
+ hh, st_conv_h_h = l. conv_h_h (g, r .* h, ps. conv_h_h, st. conv_h_h)
138
+ h̃ = xh .+ hh
139
+ h̃ = NNlib. tanh_fast (h)
140
+ h = (1 .- z). * h̃ + z.* h
141
+ return (h, h), (conv_x_r = st_conv_xr, conv_h_r = st_conv_hr, conv_x_z = st_conv_x_z, conv_h_z = st_conv_h_z, conv_x_h = st_conv_x_h, conv_h_h = st_conv_h_h)
142
+ end
143
+
144
+ function Base. show (io:: IO , l:: GConvGRUCell )
145
+ print (io, " GConvGRUCell($(l. in_dims) => $(l. out_dims) )" )
146
+ end
147
+
148
+ LuxCore. outputsize (l:: GConvGRUCell ) = (l. out_dims,)
149
+
150
+ GConvGRU (ch:: Pair{Int, Int} , k:: Int ; kwargs... ) = GNNLux. StatefulRecurrentCell (GConvGRUCell (ch, k; kwargs... ))
151
+
152
+ @concrete struct GConvLSTMCell <: GNNContainerLayer{(:conv_x_i, :conv_h_i, :dense_i, :conv_x_f, :conv_h_f, :dense_f, :conv_x_c, :conv_h_c, :dense_c, :conv_x_o, :conv_h_o, :dense_o)}
153
+ in_dims:: Int
154
+ out_dims:: Int
155
+ k:: Int
156
+ conv_x_i
157
+ conv_h_i
158
+ dense_i
159
+ conv_x_f
160
+ conv_h_f
161
+ dense_f
162
+ conv_x_c
163
+ conv_h_c
164
+ dense_c
165
+ conv_x_o
166
+ conv_h_o
167
+ dense_o
168
+ init_state:: Function
169
+ end
170
+
171
+ function GConvLSTMCell (ch:: Pair{Int, Int} , k:: Int ; use_bias = true , init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32)
172
+ in_dims, out_dims = ch
173
+ # input gate
174
+ conv_x_i = ChebConv (in_dims => out_dims, k; use_bias, init_weight, init_bias)
175
+ conv_h_i = ChebConv (out_dims => out_dims, k; use_bias, init_weight, init_bias)
176
+ dense_i = Dense (out_dims, 1 ; use_bias, init_weight, init_bias)
177
+ # forget gate
178
+ conv_x_f = ChebConv (in_dims => out_dims, k; use_bias, init_weight, init_bias)
179
+ conv_h_f = ChebConv (out_dims => out_dims, k; use_bias, init_weight, init_bias)
180
+ dense_f = Dense (out_dims, 1 ; use_bias, init_weight, init_bias)
181
+ # cell gate
182
+ conv_x_c = ChebConv (in_dims => out_dims, k; use_bias, init_weight, init_bias)
183
+ conv_h_c = ChebConv (out_dims => out_dims, k; use_bias, init_weight, init_bias)
184
+ dense_c = Dense (out_dims, 1 ; use_bias, init_weight, init_bias)
185
+ # output gate
186
+ conv_x_o = ChebConv (in_dims => out_dims, k; use_bias, init_weight, init_bias)
187
+ conv_h_o = ChebConv (out_dims => out_dims, k; use_bias, init_weight, init_bias)
188
+ dense_o = Dense (out_dims, 1 ; use_bias, init_weight, init_bias)
189
+ return GConvLSTMCell (in_dims, out_dims, k, conv_x_i, conv_h_i, dense_i, conv_x_f, conv_h_f, dense_f, conv_x_c, conv_h_c, dense_c, conv_x_o, conv_h_o, dense_o, init_state)
190
+ end
191
+
192
+ function (l:: GConvLSTMCell )(g, (x, m), ps, st)
193
+ if m === nothing
194
+ h = l. init_state (l. out_dims, g. num_nodes)
195
+ c = l. init_state (l. out_dims, g. num_nodes)
196
+ else
197
+ h, c = m
198
+ end
199
+
200
+ dense_i = StatefulLuxLayer {true} (l. dense_i, ps. dense_i, _getstate (st, :dense_i ))
201
+ dense_f = StatefulLuxLayer {true} (l. dense_f, ps. dense_f, _getstate (st, :dense_f ))
202
+ dense_c = StatefulLuxLayer {true} (l. dense_c, ps. dense_c, _getstate (st, :dense_c ))
203
+ dense_o = StatefulLuxLayer {true} (l. dense_o, ps. dense_o, _getstate (st, :dense_o ))
204
+
205
+ xi, st_conv_x_i = l. conv_x_i (g, x, ps. conv_x_i, st. conv_x_i)
206
+ hi, st_conv_h_i = l. conv_h_i (g, h, ps. conv_h_i, st. conv_h_i)
207
+ i = xi .+ hi .+ dense_i (c)
208
+ i = NNlib. sigmoid_fast (i)
209
+
210
+ xf, st_conv_x_f = l. conv_x_f (g, x, ps. conv_x_f, st. conv_x_f)
211
+ hf, st_conv_h_f = l. conv_h_f (g, h, ps. conv_h_f, st. conv_h_f)
212
+ f = xf .+ hf .+ dense_f (c)
213
+ f = NNlib. sigmoid_fast (f)
214
+
215
+ xc, st_conv_x_c = l. conv_x_c (g, x, ps. conv_x_c, st. conv_x_c)
216
+ hc, st_conv_h_c = l. conv_h_c (g, h, ps. conv_h_c, st. conv_h_c)
217
+ c = f .* c + i.* NNlib. tanh_fast (xc .+ hc .+ dense_c (c))
218
+
219
+ xo, st_conv_x_o = l. conv_x_o (g, x, ps. conv_x_o, st. conv_x_o)
220
+ ho, st_conv_h_o = l. conv_h_o (g, h, ps. conv_h_o, st. conv_h_o)
221
+ o = xo .+ ho .+ dense_o (c)
222
+ o = NNlib. sigmoid_fast (o)
223
+ h = o.* NNlib. tanh_fast (c)
224
+ return (h, (h, c)), (conv_x_i = st_conv_x_i, conv_h_i = st_conv_h_i, conv_x_f = st_conv_x_f, conv_h_f = st_conv_h_f, conv_x_c = st_conv_x_c, conv_h_c = st_conv_h_c, conv_x_o = st_conv_x_o, conv_h_o = st_conv_h_o)
225
+ end
226
+
227
+ function Base. show (io:: IO , l:: GConvLSTMCell )
228
+ print (io, " GConvLSTMCell($(l. in_dims) => $(l. out_dims) )" )
229
+ end
230
+
231
+ LuxCore. outputsize (l:: GConvLSTMCell ) = (l. out_dims,)
232
+
233
+ GConvLSTM (ch:: Pair{Int, Int} , k:: Int ; kwargs... ) = GNNLux. StatefulRecurrentCell (GConvLSTMCell (ch, k; kwargs... ))
234
+
235
+ @concrete struct DCGRUCell <: GNNContainerLayer{(:dconv_u, :dconv_r, :dconv_c)}
236
+ in_dims:: Int
237
+ out_dims:: Int
238
+ k:: Int
239
+ dconv_u
240
+ dconv_r
241
+ dconv_c
242
+ init_state:: Function
243
+ end
244
+
245
+ function DCGRUCell (ch:: Pair{Int, Int} , k:: Int ; use_bias = true , init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32)
246
+ in_dims, out_dims = ch
247
+ dconv_u = DConv ((in_dims + out_dims) => out_dims, k; use_bias = use_bias, init_weight = init_weight, init_bias = init_bias)
248
+ dconv_r = DConv ((in_dims + out_dims) => out_dims, k; use_bias = use_bias, init_weight = init_weight, init_bias = init_bias)
249
+ dconv_c = DConv ((in_dims + out_dims) => out_dims, k; use_bias = use_bias, init_weight = init_weight, init_bias = init_bias)
250
+ return DCGRUCell (in_dims, out_dims, k, dconv_u, dconv_r, dconv_c, init_state)
251
+ end
252
+
253
+ function (l:: DCGRUCell )(g, (x, h), ps, st)
254
+ if h === nothing
255
+ h = l. init_state (l. out_dims, g. num_nodes)
256
+ end
257
+ h̃ = vcat (x, h)
258
+ z, st_dconv_u = l. dconv_u (g, h̃, ps. dconv_u, st. dconv_u)
259
+ z = NNlib. sigmoid_fast .(z)
260
+ r, st_dconv_r = l. dconv_r (g, h̃, ps. dconv_r, st. dconv_r)
261
+ r = NNlib. sigmoid_fast .(r)
262
+ ĥ = vcat (x, h .* r)
263
+ c, st_dconv_c = l. dconv_c (g, ĥ, ps. dconv_c, st. dconv_c)
264
+ c = NNlib. tanh_fast .(c)
265
+ h = z.* h + (1 .- z). * c
266
+ return (h, h), (dconv_u = st_dconv_u, dconv_r = st_dconv_r, dconv_c = st_dconv_c)
267
+ end
268
+
269
+ function Base. show (io:: IO , l:: DCGRUCell )
270
+ print (io, " DCGRUCell($(l. in_dims) => $(l. out_dims) )" )
271
+ end
272
+
273
+ LuxCore. outputsize (l:: DCGRUCell ) = (l. out_dims,)
274
+
275
+ DCGRU (ch:: Pair{Int, Int} , k:: Int ; kwargs... ) = GNNLux. StatefulRecurrentCell (DCGRUCell (ch, k; kwargs... ))
276
+
0 commit comments