@@ -7,6 +7,73 @@ function scan(cell, g::GNNGraph, x::AbstractArray{T,3}, state) where {T}
    return stack(y, dims = 2)
end

+
+ """
+     GNNRecurrence(cell)
+
+ Construct a recurrent layer that applies the `cell`
+ to process an entire temporal sequence of node features at once.
+
+ # Forward
+
+     layer(g::GNNGraph, x, [state])
+
+ - `g`: The input graph.
+ - `x`: The time-varying node features. An array of size `in x timesteps x num_nodes`.
+ - `state`: The initial state of the cell.
+   If not provided, it is generated by calling `Flux.initialstates(cell)`.
+
+ Applies the recurrent cell to each timestep of the input sequence and returns the output as
+ an array of size `out_features x timesteps x num_nodes`.
+
+ # Examples
+
+ ```jldoctest
+ julia> num_nodes, num_edges = 5, 10;
+
+ julia> d_in, d_out = 2, 3;
+
+ julia> timesteps = 5;
+
+ julia> g = rand_graph(num_nodes, num_edges);
+
+ julia> x = rand(Float32, d_in, timesteps, num_nodes);
+
+ julia> cell = GConvLSTMCell(d_in => d_out, 2)
+ GConvLSTMCell(2 => 3, 2)  # 168 parameters
+
+ julia> layer = GNNRecurrence(cell)
+ GNNRecurrence(
+   GConvLSTMCell(2 => 3, 2),  # 168 parameters
+ )  # Total: 24 arrays, 168 parameters, 2.023 KiB.
+
+ julia> y = layer(g, x);
+
+ julia> size(y) # (d_out, timesteps, num_nodes)
+ (3, 5, 5)
+ ```
+ """
+ struct GNNRecurrence{G} <: GNNLayer
+     cell::G
+ end
+
+ Flux.@layer GNNRecurrence
+
+ Flux.initialstates(rnn::GNNRecurrence) = Flux.initialstates(rnn.cell)
+
+ function (rnn::GNNRecurrence)(g::GNNGraph, x::AbstractArray{T,3}) where {T}
+     return rnn(g, x, initialstates(rnn))
+ end
+
+ function (rnn::GNNRecurrence)(g::GNNGraph, x::AbstractArray{T,3}, state) where {T}
+     return scan(rnn.cell, g, x, state)
+ end
+
+ function Base.show(io::IO, rnn::GNNRecurrence)
+     print(io, "GNNRecurrence($(rnn.cell))")
+ end
+
+
"""
    GConvGRUCell(in => out, k; [bias, init])

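`GNNRecurrence` delegates the temporal loop to the `scan` helper whose closing lines appear at the top of this hunk. The helper itself lies outside the diff, so the following is only a sketch of what a `scan` with that signature plausibly does; the name `scan_sketch` and the loop body are assumptions, not the library's code.

```julia
# Hypothetical sketch: apply `cell` to each timestep, thread the state through,
# and stack the per-step outputs along dims = 2, matching the
# `return stack(y, dims = 2)` visible in the hunk above.
function scan_sketch(cell, g, x::AbstractArray{<:Any,3}, state)
    y = []
    for t in 1:size(x, 2)
        xt = @view x[:, t, :]            # in x num_nodes slice for timestep t
        yt, state = cell(g, xt, state)   # one recurrence step
        push!(y, yt)
    end
    return stack(y, dims = 2)            # out x timesteps x num_nodes
end
```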
@@ -126,24 +193,13 @@ function Base.show(io::IO, cell::GConvGRUCell)
end

"""
-     GConvGRU(in => out, k; kws...)
-
- The recurrent layer corresponding to the [`GConvGRUCell`](@ref) cell,
- used to process an entire temporal sequence of node features at once.
+     GConvGRU(args...; kws...)

- The arguments are the same as for [`GConvGRUCell`](@ref).
+ Construct a recurrent layer corresponding to the [`GConvGRUCell`](@ref) cell.
+ It can be used to process an entire temporal sequence of node features at once.

- # Forward
-
-     layer(g::GNNGraph, x, [h])
-
- - `g`: The input graph.
- - `x`: The time-varying node features. It should be an array of size `in x timesteps x num_nodes`.
- - `h`: The initial hidden state of the GRU cell. If given, it is a matrix of size `out x num_nodes`.
-   If not provided, it is assumed to be a matrix of zeros.
-
- Applies the recurrent cell to each timestep of the input sequence and returns the output as
- an array of size `out x timesteps x num_nodes`.
+ The arguments are passed to the [`GConvGRUCell`](@ref) constructor.
+ See [`GNNRecurrence`](@ref) for more details.

# Examples

@@ -158,33 +214,18 @@ julia> g = rand_graph(num_nodes, num_edges);

julia> x = rand(Float32, d_in, timesteps, num_nodes);

- julia> layer = GConvGRU(d_in => d_out, 2);
+ julia> layer = GConvGRU(d_in => d_out, 2)
+ GNNRecurrence(
+   GConvGRUCell(2 => 3, 2),  # 108 parameters
+ )  # Total: 12 arrays, 108 parameters, 1.148 KiB.

julia> y = layer(g, x);

julia> size(y) # (d_out, timesteps, num_nodes)
(3, 5, 5)
```
"""
- struct GConvGRU{G<:GConvGRUCell} <: GNNLayer
-     cell::G
- end
-
- Flux.@layer GConvGRU
-
- function GConvGRU(ch::Pair{Int,Int}, k::Int; kws...)
-     return GConvGRU(GConvGRUCell(ch, k; kws...))
- end
-
- Flux.initialstates(rnn::GConvGRU) = Flux.initialstates(rnn.cell)
-
- function (rnn::GConvGRU)(g::GNNGraph, x::AbstractArray)
-     return scan(rnn.cell, g, x, initialstates(rnn))
- end
-
- function Base.show(io::IO, rnn::GConvGRU)
-     print(io, "GConvGRU($(rnn.cell.in) => $(rnn.cell.out), $(rnn.cell.k))")
- end
+ GConvGRU(args...; kws...) = GNNRecurrence(GConvGRUCell(args...; kws...))


"""
@@ -268,7 +309,7 @@ julia> size(y) # (d_out, num_nodes)
    out::Int
end

- Flux.@layer GConvLSTMCell
+ Flux.@layer :noexpand GConvLSTMCell

function GConvLSTMCell(ch::Pair{Int, Int}, k::Int;
        bias::Bool = true,
@@ -305,6 +346,8 @@ function Flux.initialstates(cell::GConvLSTMCell)
    (zeros_like(cell.conv_x_i.weight, cell.out), zeros_like(cell.conv_x_i.weight, cell.out))
end

+ (cell::GConvLSTMCell)(g::GNNGraph, x::AbstractMatrix) = cell(g, x, initialstates(cell))
+
function (cell::GConvLSTMCell)(g::GNNGraph, x::AbstractMatrix, (h, c))
    if h isa AbstractVector
        h = repeat(h, 1, g.num_nodes)
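The two-argument method added in this hunk lets a `GConvLSTMCell` be called without an explicit `(h, c)` state. A minimal usage sketch, assuming the cell follows the same `(output, state)` return convention that the `DCGRUCell` docstring below spells out:

```julia
using GraphNeuralNetworks, Flux

g = rand_graph(5, 10)
x = rand(Float32, 2, 5)            # in x num_nodes
cell = GConvLSTMCell(2 => 3, 2)

# The state is created internally via Flux.initialstates(cell).
y, state = cell(g, x)
size(y)                            # (3, 5), i.e. out x num_nodes
```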
@@ -334,29 +377,74 @@ end


"""
-     GConvLSTM(in => out, k; kws...)
+     GConvLSTM(args...; kws...)
+
+ Construct a recurrent layer corresponding to the [`GConvLSTMCell`](@ref) cell.
+ It can be used to process an entire temporal sequence of node features at once.
+
+ The arguments are passed to the [`GConvLSTMCell`](@ref) constructor.
+ See [`GNNRecurrence`](@ref) for more details.
+
+ # Examples
+
+ ```jldoctest
+ julia> num_nodes, num_edges = 5, 10;
+
+ julia> d_in, d_out = 2, 3;
+
+ julia> timesteps = 5;
+
+ julia> g = rand_graph(num_nodes, num_edges);
+
+ julia> x = rand(Float32, d_in, timesteps, num_nodes);
+
+ julia> layer = GConvLSTM(d_in => d_out, 2)
+ GNNRecurrence(
+   GConvLSTMCell(2 => 3, 2),  # 168 parameters
+ )  # Total: 24 arrays, 168 parameters, 2.023 KiB.
+
+ julia> y = layer(g, x);
+
+ julia> size(y) # (d_out, timesteps, num_nodes)
+ (3, 5, 5)
+ ```
+ """
+ GConvLSTM(args...; kws...) = GNNRecurrence(GConvLSTMCell(args...; kws...))
+
+ """
+     DCGRUCell(in => out, k; [bias, init])
+
+ Diffusion Convolutional Recurrent Neural Network (DCGRU) cell from the paper
+ [Diffusion Convolutional Recurrent Neural Network: Data-driven Traffic Forecasting](https://arxiv.org/abs/1707.01926).
+
+ Applies a [`DConv`](@ref) layer to model spatial dependencies,
+ in combination with a Gated Recurrent Unit (GRU) cell to model temporal dependencies.

- The recurrent layer corresponding to the [`GConvLSTMCell`](@ref) cell,
- used to process an entire temporal sequence of node features at once.
+ # Arguments

- The arguments are the same as for [`GConvLSTMCell`](@ref).
+ - `in`: Number of input node features.
+ - `out`: Number of output node features.
+ - `k`: Diffusion step for the `DConv`.
+ - `bias`: Add learnable bias. Default `true`.
+ - `init`: Weights' initializer. Default `glorot_uniform`.

# Forward

-     layer(g::GNNGraph, x, [state])
+     cell(g::GNNGraph, x, [h])

- `g`: The input graph.
- - `x`: The time-varying node features. It should be an array of size `in x timesteps x num_nodes`.
- - `state`: The initial hidden state of the LSTM cell.
-   If given, it is a tuple `(h, c)` where both elements are matrices of size `out x num_nodes`.
-   If not provided, the initial hidden state is assumed to be a tuple of matrices of zeros.
+ - `x`: The node features. It should be a matrix of size `in x num_nodes`.
+ - `h`: The initial hidden state of the GRU cell. If given, it is a matrix of size `out x num_nodes`.
+   If not provided, it is assumed to be a matrix of zeros.

- Applies the recurrent cell to each timestep of the input sequence and returns the output as
- an array of size `out x timesteps x num_nodes`.
+ Performs one recurrence step and returns a tuple `(h, h)`,
+ where `h` is the updated hidden state of the GRU cell.

# Examples

```jldoctest
+ julia> using GraphNeuralNetworks, Flux
+
julia> num_nodes, num_edges = 5, 10;

julia> d_in, d_out = 2, 3;
@@ -365,33 +453,98 @@ julia> timesteps = 5;

julia> g = rand_graph(num_nodes, num_edges);

- julia> x = rand(Float32, d_in, timesteps, num_nodes);
+ julia> x = [rand(Float32, d_in, num_nodes) for t in 1:timesteps];

- julia> layer = GConvLSTM(d_in => d_out, 2);
+ julia> cell = DCGRUCell(d_in => d_out, 2);

- julia> y = layer(g, x);
+ julia> state = Flux.initialstates(cell);

- julia> size(y) # (d_out, timesteps, num_nodes)
- (3, 5, 5)
+ julia> y = state;
+
+ julia> for xt in x
+            y, state = cell(g, xt, state)
+        end
+
+ julia> size(y) # (d_out, num_nodes)
+ (3, 5)
```
- """
378
- struct GConvLSTM{G<: GConvLSTMCell } <: GNNLayer
379
- cell:: G
471
+ """
472
+ struct DCGRUCell
473
+ in:: Int
474
+ out:: Int
475
+ k:: Int
476
+ dconv_u:: DConv
477
+ dconv_r:: DConv
478
+ dconv_c:: DConv
380
479
end
381
480
382
- Flux. @layer GConvLSTM
481
+ Flux. @layer :noexpand DCGRUCell
383
482
384
- function GConvLSTM (ch:: Pair{Int,Int} , k:: Int ; kws... )
385
- return GConvLSTM (GConvLSTMCell (ch, k; kws... ))
483
+ function DCGRUCell (ch:: Pair{Int,Int} , k:: Int ; bias = true , init = glorot_uniform)
484
+ in, out = ch
485
+ dconv_u = DConv ((in + out) => out, k; bias, init)
486
+ dconv_r = DConv ((in + out) => out, k; bias, init)
487
+ dconv_c = DConv ((in + out) => out, k; bias, init)
488
+ return DCGRUCell (in, out, k, dconv_u, dconv_r, dconv_c)
386
489
end
387
490
388
- Flux. initialstates (rnn:: GConvLSTM ) = Flux. initialstates (rnn. cell)
491
+ Flux. initialstates (cell:: DCGRUCell ) = zeros_like (cell. dconv_u. weights, cell. out)
492
+
493
+ (cell:: DCGRUCell )(g:: GNNGraph , x:: AbstractMatrix ) = cell (g, x, initialstates (cell))
389
494
390
- function (rnn :: GConvLSTM )(g:: GNNGraph , x:: AbstractArray )
391
- return scan (rnn . cell, g, x, initialstates (rnn ))
495
+ function (cell :: DCGRUCell )(g:: GNNGraph , x:: AbstractMatrix , h :: AbstractVector )
496
+ return cell ( g, x, repeat (h, 1 , g . num_nodes ))
392
497
end
393
498
394
- function Base. show (io:: IO , rnn:: GConvLSTM )
395
- print (io, " GConvLSTM($(rnn. cell. in) => $(rnn. cell. out) , $(rnn. cell. k) )" )
499
+ function (cell:: DCGRUCell )(g:: GNNGraph , x:: AbstractMatrix , h:: AbstractMatrix )
500
+ h̃ = vcat (x, h)
501
+ z = cell. dconv_u (g, h̃)
502
+ z = NNlib. sigmoid_fast .(z)
503
+ r = cell. dconv_r (g, h̃)
504
+ r = NNlib. sigmoid_fast .(r)
505
+ ĥ = vcat (x, h .* r)
506
+ c = cell. dconv_c (g, ĥ)
507
+ c = NNlib. tanh_fast .(c)
508
+ h = z.* h + (1 .- z) .* c
509
+ return h, h
510
+ end
511
+
512
+ function Base. show (io:: IO , cell:: DCGRUCell )
513
+ print (io, " DCGRUCell($(cell. in) => $(cell. out) , $(cell. k) )" )
396
514
end
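For intuition, the gate computation added above is the standard GRU update with each `DConv` standing in for a dense transform. A minimal dense analogue, using Flux only and hypothetical helper names (this is a sketch for comparison, not part of the diff):

```julia
using Flux

d_in, d_out, num_nodes = 2, 3, 5
Wz, Wr, Wc = (Dense(d_in + d_out => d_out) for _ in 1:3)

# Same gating structure as the DCGRUCell forward pass, with Dense in place of DConv.
function gru_like_step(Wz, Wr, Wc, x, h)
    z = sigmoid.(Wz(vcat(x, h)))       # update gate
    r = sigmoid.(Wr(vcat(x, h)))       # reset gate
    c = tanh.(Wc(vcat(x, h .* r)))     # candidate hidden state
    return z .* h .+ (1 .- z) .* c     # mix old state and candidate
end

h = gru_like_step(Wz, Wr, Wc, rand(Float32, d_in, num_nodes), zeros(Float32, d_out, num_nodes))
size(h)  # (3, 5)
```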

+ """
+     DCGRU(args...; kws...)
+
+ Construct a recurrent layer corresponding to the [`DCGRUCell`](@ref) cell.
+ It can be used to process an entire temporal sequence of node features at once.
+
+ The arguments are passed to the [`DCGRUCell`](@ref) constructor.
+ See [`GNNRecurrence`](@ref) for more details.
+
+ # Examples
+ ```jldoctest
+ julia> num_nodes, num_edges = 5, 10;
+
+ julia> d_in, d_out = 2, 3;
+
+ julia> timesteps = 5;
+
+ julia> g = rand_graph(num_nodes, num_edges);
+
+ julia> x = rand(Float32, d_in, timesteps, num_nodes);
+
+ julia> layer = DCGRU(d_in => d_out, 2)
+ GNNRecurrence(
+   DCGRUCell(2 => 3, 2),  # 189 parameters
+ )  # Total: 6 arrays, 189 parameters, 1.184 KiB.
+
+ julia> y = layer(g, x);
+
+ julia> size(y) # (d_out, timesteps, num_nodes)
+ (3, 5, 5)
+ ```
+ """
+ DCGRU(args...; kws...) = GNNRecurrence(DCGRUCell(args...; kws...))
+
+
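Taken together, after this change `GConvGRU`, `GConvLSTM`, and `DCGRU` are thin constructors around `GNNRecurrence`. A quick usage sketch, assuming these names are exported as the docstrings suggest:

```julia
using GraphNeuralNetworks, Flux

g = rand_graph(4, 8)
x = rand(Float32, 2, 10, 4)                  # in x timesteps x num_nodes

# Each constructor returns a GNNRecurrence wrapping the corresponding cell.
for layer in (GConvGRU(2 => 3, 2), GConvLSTM(2 => 3, 2), DCGRU(2 => 3, 2))
    @assert layer isa GNNRecurrence
    @assert size(layer(g, x)) == (3, 10, 4)  # out x timesteps x num_nodes
end
```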