
Commit b229ab2

EvolveGCNO
1 parent 13142f2 commit b229ab2

1 file changed: +106 −9 lines changed

GraphNeuralNetworks/src/layers/temporalconv.jl

Lines changed: 106 additions & 9 deletions
@@ -22,23 +22,42 @@ end
 """
     GNNRecurrence(cell)
 
-Construct a recurrent layer that applies the `cell`
-to process an entire temporal sequence of node features at once.
+Construct a recurrent layer that applies the graph recurrent `cell` forward
+multiple times to process an entire temporal sequence of node features at once.
+
+The `cell` has to satisfy the following interface for the forward pass:
+`yt, state = cell(g, xt, state)`, where `xt` is the input node features,
+`yt` is the updated node features, and `state` is the cell state to be updated.
 
 # Forward
 
-    layer(g::GNNGraph, x, [state])
+    layer(g, x, [state])
 
-- `g`: The input graph.
-- `x`: The time-varying node features. An array of size `in x timesteps x num_nodes`.
-- `state`: The current state of the cell.
+Applies the recurrent cell to each timestep of the input sequence.
+
+## Arguments
+
+- `g`: The input `GNNGraph` or `TemporalSnapshotsGNNGraph`.
+   - If `GNNGraph`, the same graph is used for all timesteps.
+   - If `TemporalSnapshotsGNNGraph`, a different graph is used for each timestep. Not all cells support this.
+- `x`: The time-varying node features.
+   - If `g` is a `GNNGraph`, it is an array of size `in x timesteps x num_nodes`.
+   - If `g` is a `TemporalSnapshotsGNNGraph`, it is a vector of length `timesteps`,
+     with element `t` of size `in x num_nodes_t`.
+- `state`: The initial state for the cell.
   If not provided, it is generated by calling `Flux.initialstates(cell)`.
 
-Applies the recurrent cell to each timestep of the input sequence and returns the output as
-an array of size `out_features x timesteps x num_nodes`.
+## Return
+
+Returns the updated node features:
+- If `g` is a `GNNGraph`, returns an array of size `out_features x timesteps x num_nodes`.
+- If `g` is a `TemporalSnapshotsGNNGraph`, returns a vector of length `timesteps`,
+  with element `t` of size `out_features x num_nodes_t`.
 
 # Examples
 
+The following example considers a static graph and time-varying node features.
+
 ```jldoctest
 julia> num_nodes, num_edges = 5, 10;
 
@@ -47,6 +66,9 @@ julia> d_in, d_out = 2, 3;
 julia> timesteps = 5;
 
 julia> g = rand_graph(num_nodes, num_edges);
+GNNGraph:
+  num_nodes: 5
+  num_edges: 10
 
 julia> x = rand(Float32, d_in, timesteps, num_nodes);
 
@@ -63,6 +85,38 @@ julia> y = layer(g, x);
 julia> size(y) # (d_out, timesteps, num_nodes)
 (3, 5, 5)
 ```
+Now consider a time-varying graph and time-varying node features.
+```jldoctest
+julia> d_in, d_out = 2, 3;
+
+julia> timesteps = 5;
+
+julia> num_nodes = [10, 10, 10, 10, 10];
+
+julia> num_edges = [10, 12, 14, 16, 18];
+
+julia> snapshots = [rand_graph(n, m) for (n, m) in zip(num_nodes, num_edges)];
+
+julia> tg = TemporalSnapshotsGNNGraph(snapshots);
+
+julia> x = [rand(Float32, d_in, n) for n in num_nodes];
+
+julia> cell = EvolveGCNOCell(d_in => d_out)
+EvolveGCNOCell(2 => 3) # 321 parameters
+
+julia> layer = GNNRecurrence(cell)
+GNNRecurrence(
+  EvolveGCNOCell(2 => 3),               # 321 parameters
+) # Total: 5 arrays, 321 parameters, 1.535 KiB.
+
+julia> y = layer(tg, x);
+
+julia> length(y) # timesteps
+5
+
+julia> size(y[end]) # (d_out, num_nodes[end])
+(3, 10)
+```
 """
 struct GNNRecurrence{G} <: GNNLayer
     cell::G
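The cell interface documented in the new docstring, `yt, state = cell(g, xt, state)`, is enough to unroll a temporal sequence by hand. Below is a minimal illustrative sketch (not part of this commit) of such an unroll over a static graph, assuming the exported names used in the doctests above (`rand_graph`, `EvolveGCNOCell`, `Flux.initialstates`).

```julia
using Flux, GraphNeuralNetworks

# Illustrative sketch (not part of this commit): unroll a graph recurrent cell by hand
# over a static graph, following the documented interface `yt, state = cell(g, xt, state)`.
function unroll(cell, g, x)                 # x has size in x timesteps x num_nodes
    state = Flux.initialstates(cell)        # default initial state, as in the docstring
    ys = map(1:size(x, 2)) do t
        xt = x[:, t, :]                     # features at timestep t, size in x num_nodes
        yt, state = cell(g, xt, state)      # one recurrent step, carrying the state forward
        yt
    end
    return stack(ys; dims = 2)              # out x timesteps x num_nodes
end

g = rand_graph(5, 10)                       # static graph reused at every timestep
x = rand(Float32, 2, 5, 5)                  # in x timesteps x num_nodes
cell = EvolveGCNOCell(2 => 3)
y = unroll(cell, g, x)
size(y)                                     # (3, 5, 5)
```

Per the docstring above, `GNNRecurrence(cell)` packages this kind of per-timestep loop behind the single `layer(g, x, [state])` call.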
@@ -437,7 +491,7 @@ in combination with a Gated Recurrent Unit (GRU) cell to model temporal dependen
 - `out`: Number of output node features.
 - `k`: Diffusion step for the `DConv`.
 - `bias`: Add learnable bias. Default `true`.
-- `init`: Weights' initializer. Default `glorot_uniform`.
+- `init`: Convolution weights' initializer. Default `glorot_uniform`.
 
 # Forward
 
@@ -651,3 +705,46 @@ end
 function Base.show(io::IO, egcno::EvolveGCNOCell)
     print(io, "EvolveGCNOCell($(egcno.in) => $(egcno.out))")
 end
+
+
+"""
+    EvolveGCNO(args...; kws...)
+
+Construct a recurrent layer corresponding to the [`EvolveGCNOCell`](@ref) cell.
+It can be used to process an entire temporal sequence of graphs and node features at once.
+
+The arguments are passed to the [`EvolveGCNOCell`](@ref) constructor.
+See [`GNNRecurrence`](@ref) for more details.
+
+# Examples
+
+```jldoctest
+julia> d_in, d_out = 2, 3;
+
+julia> timesteps = 5;
+
+julia> num_nodes = [10, 10, 10, 10, 10];
+
+julia> num_edges = [10, 12, 14, 16, 18];
+
+julia> snapshots = [rand_graph(n, m) for (n, m) in zip(num_nodes, num_edges)];
+
+julia> tg = TemporalSnapshotsGNNGraph(snapshots);
+
+julia> x = [rand(Float32, d_in, n) for n in num_nodes];
+
+julia> layer = EvolveGCNO(d_in => d_out)
+GNNRecurrence(
+  EvolveGCNOCell(2 => 3),               # 321 parameters
+) # Total: 5 arrays, 321 parameters, 1.535 KiB.
+
+julia> y = layer(tg, x);
+
+julia> length(y) # timesteps
+5
+
+julia> size(y[end]) # (d_out, num_nodes[end])
+(3, 10)
+```
+"""
+EvolveGCNO(args...; kws...) = GNNRecurrence(EvolveGCNOCell(args...; kws...))
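The definition added at the end simply forwards to `GNNRecurrence`, so the wrapper and the explicit construction build the same kind of layer; a short illustrative snippet (not part of the commit):

```julia
using GraphNeuralNetworks

# Both lines construct a GNNRecurrence wrapping an EvolveGCNOCell(2 => 3);
# they are interchangeable and are called the same way, e.g. `layer(tg, x)`.
layer_a = EvolveGCNO(2 => 3)
layer_b = GNNRecurrence(EvolveGCNOCell(2 => 3))
```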
