|
438 | 438 |
|
439 | 439 | LuxCore.outputsize(l::DCGRUCell) = (l.out_dims,) |
440 | 440 |
|
| 441 | +""" |
| 442 | + DCGRU(in => out, k; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32) |
| 443 | +
|
| 444 | +Diffusion Convolutional Recurrent Neural Network (DCGRU) layer from the paper [Diffusion Convolutional Recurrent Neural |
| 445 | +Network: Data-driven Traffic Forecasting](https://arxiv.org/pdf/1707.01926). |
| 446 | +
|
| 447 | +Performs a Diffusion Convolutional layer to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies. |
| 448 | +
|
| 449 | +# Arguments |
| 450 | +
|
| 451 | +- `in`: Number of input features. |
| 452 | +- `out`: Number of output features. |
| 453 | +- `k`: Diffusion step. |
| 454 | +- `use_bias`: Add learnable bias. Default `true`. |
| 455 | +- `init_weight`: Weights' initializer. Default `glorot_uniform`. |
| 456 | +- `init_state`: Initial state of the hidden stat of the GRU layer. Default `zeros32`. |
| 457 | +- `init_bias`: Bias initializer. Default `zeros32`. |
| 458 | +
|
| 459 | +# Examples |
| 460 | +
|
| 461 | +```julia |
| 462 | +using GNNLux, Lux, Random |
| 463 | +
|
| 464 | +# initialize random number generator |
| 465 | +rng = Random.default_rng() |
| 466 | +
|
| 467 | +# create data |
| 468 | +g = rand_graph(rng, 5, 10) |
| 469 | +x = rand(rng, Float32, 2, 5) |
| 470 | +
|
| 471 | +# create layer |
| 472 | +l = DCGRU(2 => 5, 2) |
| 473 | +
|
| 474 | +# setup layer |
| 475 | +ps, st = LuxCore.setup(rng, l) |
| 476 | +
|
| 477 | +# forward pass |
| 478 | +y, st = l(g, x, ps, st) # result size (5, 5) |
| 479 | +``` |
| 480 | +""" |
441 | 481 | DCGRU(ch::Pair{Int, Int}, k::Int; kwargs...) = GNNLux.StatefulRecurrentCell(DCGRUCell(ch, k; kwargs...)) |
442 | 482 |
|
443 | 483 | @concrete struct EvolveGCNO <: GNNLayer |
|
0 commit comments