|
1 |
| -""" |
2 |
| - TGCNCell(in => out; [bias, init, add_self_loops, use_edge_weight]) |
3 |
| -
|
4 |
| -Temporal Graph Convolutional Network (T-GCN) cell from the paper [T-GCN: A Temporal Graph Convolutional Network for Traffic Prediction](https://arxiv.org/pdf/1811.05320.pdf). |
5 |
| -
|
6 |
| -Performs a layer of GCNConv to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies. |
7 |
| -
|
8 |
| -# Arguments |
| 1 | +# Adapting Flux.Recur to work with GNNGraphs |
| 2 | +function (m::Flux.Recur)(g::GNNGraph, x) |
| 3 | + m.state, y = m.cell(m.state, g, x) |
| 4 | + return y |
| 5 | +end |
| 6 | + |
| 7 | +function (m::Flux.Recur)(g::GNNGraph, x::AbstractArray{T, 3}) where T |
| 8 | + h = [m(g, x_t) for x_t in Flux.eachlastdim(x)] |
| 9 | + sze = size(h[1]) |
| 10 | + reshape(reduce(hcat, h), sze[1], sze[2], length(h)) |
| 11 | +end |
9 | 12 |
|
10 |
| -- `in`: Number of input features. |
11 |
| -- `out`: Number of output features. |
12 |
| -- `bias`: Add learnable bias. Default `true`. |
13 |
| -- `init`: Weights' initializer. Default `glorot_uniform`. |
14 |
| -- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `false`. |
15 |
| -- `use_edge_weight`: If `true`, consider the edge weights in the input graph (if available). |
16 |
| - If `add_self_loops=true` the new weights will be set to 1. |
17 |
| - This option is ignored if the `edge_weight` is explicitly provided in the forward pass. |
18 |
| - Default `false`. |
19 |
| -""" |
20 | 13 | struct TGCNCell <: GNNLayer
|
21 | 14 | conv::GCNConv
|
22 | 15 | gru::Flux.GRUv3Cell
|
|
50 | 43 | function Base.show(io::IO, tgcn::TGCNCell)
|
51 | 44 | print(io, "TGCNCell($(tgcn.in) => $(tgcn.out))")
|
52 | 45 | end
|
| 46 | + |
| 47 | +""" |
| 48 | + TGCN(in => out; [bias, init, add_self_loops, use_edge_weight]) |
| 49 | +
|
| 50 | +Temporal Graph Convolutional Network (T-GCN) recurrent layer from the paper [T-GCN: A Temporal Graph Convolutional Network for Traffic Prediction](https://arxiv.org/pdf/1811.05320.pdf). |
| 51 | +
|
| 52 | +Performs a layer of GCNConv to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies. |
| 53 | +
|
| 54 | +# Arguments |
| 55 | +
|
| 56 | +- `in`: Number of input features. |
| 57 | +- `out`: Number of output features. |
| 58 | +- `bias`: Add learnable bias. Default `true`. |
| 59 | +- `init`: Weights' initializer. Default `glorot_uniform`. |
| 60 | +- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `false`. |
| 61 | +- `use_edge_weight`: If `true`, consider the edge weights in the input graph (if available). |
| 62 | + If `add_self_loops=true` the new weights will be set to 1. |
| 63 | + This option is ignored if the `edge_weight` is explicitly provided in the forward pass. |
| 64 | + Default `false`. |
| 65 | +# Examples |
| 66 | +
|
| 67 | +```jldoctest |
| 68 | +julia> tgcn = TGCN(2 => 6) |
| 69 | +Recur( |
| 70 | + TGCNCell( |
| 71 | + GCNConv(2 => 6, σ), # 18 parameters |
| 72 | + GRUv3Cell(6 => 6), # 240 parameters |
| 73 | + Float32[0.0; 0.0; … ; 0.0; 0.0;;], # 6 parameters (all zero) |
| 74 | + 2, |
| 75 | + 6, |
| 76 | + ), |
| 77 | +) # Total: 8 trainable arrays, 264 parameters, |
| 78 | + # plus 1 non-trainable, 6 parameters, summarysize 1.492 KiB. |
| 79 | +
|
| 80 | +julia> g, x = rand_graph(5, 10), rand(Float32, 2, 5); |
| 81 | +
|
| 82 | +julia> y = tgcn(g, x); |
| 83 | +
|
| 84 | +julia> size(y) |
| 85 | +(6, 5) |
| 86 | +
|
| 87 | +julia> Flux.reset!(tgcn); |
| 88 | +
|
| 89 | +julia> tgcn(rand_graph(5, 10), rand(Float32, 2, 5, 20)) |> size # batch size of 20 |
| 90 | +(6, 5, 20) |
| 91 | +``` |
| 92 | +
|
| 93 | +!!! warning "Batch size changes" |
| 94 | + Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. |
| 95 | +""" |
| 96 | +TGCN(ch; kwargs...) = Flux.Recur(TGCNCell(ch; kwargs...)) |
| 97 | + |
| 98 | +Flux.Recur(tgcn::TGCNCell) = Flux.Recur(tgcn, tgcn.state0) |
| 99 | + |
| 100 | +# make TGCN compatible with GNNChain |
| 101 | +(l::Flux.Recur{TGCNCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g))) |
| 102 | +_applylayer(l::Flux.Recur{TGCNCell}, g::GNNGraph, x) = l(g, x) |
| 103 | +_applylayer(l::Flux.Recur{TGCNCell}, g::GNNGraph) = l(g) |
0 commit comments