Skip to content

Commit 383d464

Browse files
EvolveGCNOCell
1 parent cfd9ec3 commit 383d464

File tree

3 files changed

+206
-113
lines changed

3 files changed

+206
-113
lines changed

GraphNeuralNetworks/src/GraphNeuralNetworks.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@ export GNNRecurrence,
5454
GConvGRU, GConvGRUCell,
5555
GConvLSTM, GConvLSTMCell,
5656
DCGRU, DCGRUCell,
57+
EvolveGCNO, EvolveGCNOCell,
5758
TGCN,
58-
A3TGCN,
59-
EvolveGCNO
59+
A3TGCN
6060

6161
include("layers/pool.jl")
6262
export GlobalPool,

GraphNeuralNetworks/src/layers/temporalconv.jl

Lines changed: 112 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,23 @@
11
function scan(cell, g::GNNGraph, x::AbstractArray{T,3}, state) where {T}
22
y = []
3-
for x_t in eachslice(x, dims = 2)
4-
yt, state = cell(g, x_t, state)
3+
for xt in eachslice(x, dims = 2)
4+
yt, state = cell(g, xt, state)
55
y = vcat(y, [yt])
66
end
77
return stack(y, dims = 2)
88
end
99

10+
function scan(cell, tg::TemporalSnapshotsGNNGraph, x::AbstractVector, state)
11+
# @assert length(x) == length(tg.snapshots)
12+
y = []
13+
for (t, xt) in enumerate(x)
14+
gt = tg.snapshots[t]
15+
yt, state = cell(gt, xt, state)
16+
y = vcat(y, [yt])
17+
end
18+
return y
19+
end
20+
1021

1122
"""
1223
GNNRecurrence(cell)
@@ -20,7 +31,7 @@ to process an entire temporal sequence of node features at once.
2031
2132
- `g`: The input graph.
2233
- `x`: The time-varying node features. An array of size `in x timesteps x num_nodes`.
23-
- `state`: The initial state of the cell.
34+
- `state`: The current state of the cell.
2435
If not provided, it is generated by calling `Flux.initialstates(cell)`.
2536
2637
Applies the recurrent cell to each timestep of the input sequence and returns the output as
@@ -61,11 +72,11 @@ Flux.@layer GNNRecurrence
6172

6273
Flux.initialstates(rnn::GNNRecurrence) = Flux.initialstates(rnn.cell)
6374

64-
function (rnn::GNNRecurrence)(g::GNNGraph, x::AbstractArray{T,3}) where {T}
75+
function (rnn::GNNRecurrence)(g, x)
6576
return rnn(g, x, initialstates(rnn))
6677
end
6778

68-
function (rnn::GNNRecurrence)(g::GNNGraph, x::AbstractArray{T,3}, state) where {T}
79+
function (rnn::GNNRecurrence)(g, x, state)
6980
return scan(rnn.cell, g, x, state)
7081
end
7182

@@ -97,7 +108,7 @@ followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies.
97108
98109
- `g`: The input graph.
99110
- `x`: The node features. It should be a matrix of size `in x num_nodes`.
100-
- `h`: The initial hidden state of the GRU cell. If given, it is a matrix of size `out x num_nodes`.
111+
- `h`: The current hidden state of the GRU cell. If given, it is a matrix of size `out x num_nodes`.
101112
If not provided, it is assumed to be a matrix of zeros.
102113
103114
Performs one recurrence step and returns a tuple `(h, h)`,
@@ -251,9 +262,9 @@ followed by a Long Short-Term Memory (LSTM) cell to model temporal dependencies.
251262
252263
- `g`: The input graph.
253264
- `x`: The node features. It should be a matrix of size `in x num_nodes`.
254-
- `state`: The initial hidden state of the LSTM cell.
265+
- `state`: The current state of the LSTM cell.
255266
If given, it is a tuple `(h, c)` where both `h` and `c` are arrays of size `out x num_nodes`.
256-
If not provided, the initial hidden state is assumed to be a tuple of matrices of zeros.
267+
If not provided, it is assumed to be a tuple of matrices of zeros.
257268
258269
Performs one recurrence step and returns a tuple `(output, state)`,
259270
where `output` is the updated hidden state `h` of the LSTM cell and `state` is the updated tuple `(h, c)`.
@@ -434,7 +445,7 @@ in combination with a Gated Recurrent Unit (GRU) cell to model temporal dependen
434445
435446
- `g`: The input graph.
436447
- `x`: The node features. It should be a matrix of size `in x num_nodes`.
437-
- `h`: The initial hidden state of the GRU cell. If given, it is a matrix of size `out x num_nodes`.
448+
- `h`: The current state of the GRU cell. It is a matrix of size `out x num_nodes`.
438449
If not provided, it is assumed to be a matrix of zeros.
439450
440451
Performs one recurrence step and returns a tuple `(h, h)`,
@@ -547,4 +558,96 @@ julia> size(y) # (d_out, timesteps, num_nodes)
547558
"""
548559
DCGRU(args...; kws...) = GNNRecurrence(DCGRUCell(args...; kws...))
549560

561+
"""
562+
EvolveGCNOCell(in => out; bias = true, init = glorot_uniform)
563+
564+
Evolving Graph Convolutional Network cell of type "-O" from the paper
565+
[EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs](https://arxiv.org/abs/1902.10191).
566+
567+
Uses a [`GCNConv`](@ref) layer to model spatial dependencies, and an `LSTMCell` to model temporal dependencies.
568+
Can work with time-varying graphs and node features.
569+
570+
# Arguments
571+
572+
- `in => out`: A pair where `in` is the number of input node features and `out`
573+
is the number of output node features.
574+
- `bias`: Add learnable bias for the convolution and the lstm cell. Default `true`.
575+
- `init`: Weights' initializer for the convolution. Default `glorot_uniform`.
576+
577+
# Forward
578+
579+
cell(g::GNNGraph, x, [state]) -> x, state
580+
581+
- `g`: The input graph.
582+
- `x`: The node features. It should be a matrix of size `in x num_nodes`.
583+
- `state`: The current state of the cell.
584+
A state is a tuple `(weight, lstm)` where `weight` is the convolution's weight and `lstm` is the lstm's state.
585+
If not provided, it is generated by calling `Flux.initialstates(cell)`.
586+
587+
Returns the updated node features `x` and the updated state.
588+
589+
```jldoctest
590+
julia> using GraphNeuralNetworks, Flux
591+
592+
julia> num_nodes, num_edges = 5, 10;
593+
594+
julia> d_in, d_out = 2, 3;
595+
596+
julia> timesteps = 5;
597+
598+
julia> g = [rand_graph(num_nodes, num_edges) for t in 1:timesteps];
599+
600+
julia> x = [rand(Float32, d_in, num_nodes) for t in 1:timesteps];
601+
602+
julia> cell1 = EvolveGCNOCell(d_in => d_out)
603+
EvolveGCNOCell(2 => 3) # 321 parameters
604+
605+
julia> cell2 = EvolveGCNOCell(d_out => d_out)
606+
EvolveGCNOCell(3 => 3) # 696 parameters
607+
608+
julia> state1 = Flux.initialstates(cell1);
609+
610+
julia> state2 = Flux.initialstates(cell2);
611+
612+
julia> outputs = [];
613+
614+
julia> for t in 1:timesteps
615+
zt, state1 = cell1(g[t], x[t], state1)
616+
yt, state2 = cell2(g[t], zt, state2)
617+
outputs = vcat(outputs, [yt])
618+
end
619+
620+
julia> size(outputs[end]) # (d_out, num_nodes)
621+
(3, 5)
622+
```
623+
"""
624+
struct EvolveGCNOCell{C,L} <: GNNLayer
625+
in::Int
626+
out::Int
627+
conv::C
628+
lstm::L
629+
end
630+
631+
Flux.@layer :noexpand EvolveGCNOCell
550632

633+
function EvolveGCNOCell((in,out)::Pair{Int,Int}; bias = true, init = glorot_uniform)
634+
conv = GCNConv(in => out; bias, init)
635+
lstm = LSTMCell(in*out => in*out; bias)
636+
return EvolveGCNOCell(in, out, conv, lstm)
637+
end
638+
639+
function Flux.initialstates(cell::EvolveGCNOCell)
640+
weight = reshape(cell.conv.weight, :)
641+
lstm = Flux.initialstates(cell.lstm)
642+
return (; weight, lstm)
643+
end
644+
645+
function (cell::EvolveGCNOCell)(g::GNNGraph, x::AbstractMatrix, state)
646+
weight, state_lstm = cell.lstm(state.weight, state.lstm)
647+
x = cell.conv(g, x, conv_weight = reshape(weight, (cell.out, cell.in)))
648+
return x, (; weight, lstm = state_lstm)
649+
end
650+
651+
function Base.show(io::IO, egcno::EvolveGCNOCell)
652+
print(io, "EvolveGCNOCell($(egcno.in) => $(egcno.out))")
653+
end

GraphNeuralNetworks/src/layers/temporalconv_old.jl

Lines changed: 92 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,3 @@
1-
function scan(cell, g, x, state)
2-
y = []
3-
for x_t in eachslice(x, dims = 2)
4-
yt, state = cell(g, x_t, state)
5-
y = vcat(y, [yt])
6-
end
7-
return stack(y, dims = 2)
8-
end
9-
10-
111
struct TGCNCell{C,G} <: GNNLayer
122
conv::C
133
gru::G
@@ -179,97 +169,97 @@ function Base.show(io::IO, a3tgcn::A3TGCN)
179169
end
180170

181171

182-
# """
183-
# EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32)
184-
185-
# Evolving Graph Convolutional Network (EvolveGCNO) layer from the paper [EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs](https://arxiv.org/pdf/1902.10191).
186-
187-
# Perfoms a Graph Convolutional layer with parameters derived from a Long Short-Term Memory (LSTM) layer across the snapshots of the temporal graph.
188-
189-
190-
# # Arguments
191-
192-
# - `in`: Number of input features.
193-
# - `out`: Number of output features.
194-
# - `bias`: Add learnable bias. Default `true`.
195-
# - `init`: Weights' initializer. Default `glorot_uniform`.
196-
# - `init_state`: Initial state of the hidden stat of the LSTM layer. Default `zeros32`.
197-
198-
# # Examples
199-
200-
# ```jldoctest
201-
# julia> tg = TemporalSnapshotsGNNGraph([rand_graph(10,20; ndata = rand(4,10)), rand_graph(10,14; ndata = rand(4,10)), rand_graph(10,22; ndata = rand(4,10))])
202-
# TemporalSnapshotsGNNGraph:
203-
# num_nodes: [10, 10, 10]
204-
# num_edges: [20, 14, 22]
205-
# num_snapshots: 3
206-
207-
# julia> ev = EvolveGCNO(4 => 5)
208-
# EvolveGCNO(4 => 5)
209-
210-
# julia> size(ev(tg, tg.ndata.x))
211-
# (3,)
212-
213-
# julia> size(ev(tg, tg.ndata.x)[1])
214-
# (5, 10)
215-
# ```
216-
# """
217-
# struct EvolveGCNO
218-
# conv
219-
# W_init
220-
# init_state
221-
# in::Int
222-
# out::Int
223-
# Wf
224-
# Uf
225-
# Bf
226-
# Wi
227-
# Ui
228-
# Bi
229-
# Wo
230-
# Uo
231-
# Bo
232-
# Wc
233-
# Uc
234-
# Bc
235-
# end
172+
"""
173+
EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32)
174+
175+
Evolving Graph Convolutional Network (EvolveGCNO) layer from the paper [EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs](https://arxiv.org/pdf/1902.10191).
176+
177+
Performs a Graph Convolutional layer with parameters derived from a Long Short-Term Memory (LSTM) layer across the snapshots of the temporal graph.
178+
179+
180+
# Arguments
181+
182+
- `in`: Number of input features.
183+
- `out`: Number of output features.
184+
- `bias`: Add learnable bias. Default `true`.
185+
- `init`: Weights' initializer. Default `glorot_uniform`.
186+
- `init_state`: Initial state of the hidden state of the LSTM layer. Default `zeros32`.
187+
188+
# Examples
189+
190+
```jldoctest
191+
julia> tg = TemporalSnapshotsGNNGraph([rand_graph(10,20; ndata = rand(4,10)), rand_graph(10,14; ndata = rand(4,10)), rand_graph(10,22; ndata = rand(4,10))])
192+
TemporalSnapshotsGNNGraph:
193+
num_nodes: [10, 10, 10]
194+
num_edges: [20, 14, 22]
195+
num_snapshots: 3
196+
197+
julia> ev = EvolveGCNO(4 => 5)
198+
EvolveGCNO(4 => 5)
199+
200+
julia> size(ev(tg, tg.ndata.x))
201+
(3,)
202+
203+
julia> size(ev(tg, tg.ndata.x)[1])
204+
(5, 10)
205+
```
206+
"""
207+
struct EvolveGCNO
208+
conv
209+
W_init
210+
init_state
211+
in::Int
212+
out::Int
213+
Wf
214+
Uf
215+
Bf
216+
Wi
217+
Ui
218+
Bi
219+
Wo
220+
Uo
221+
Bo
222+
Wc
223+
Uc
224+
Bc
225+
end
236226

237-
# function EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32)
238-
# in, out = ch
239-
# W = init(out, in)
240-
# conv = GCNConv(ch; bias = bias, init = init)
241-
# Wf = init(out, in)
242-
# Uf = init(out, in)
243-
# Bf = bias ? init(out, in) : nothing
244-
# Wi = init(out, in)
245-
# Ui = init(out, in)
246-
# Bi = bias ? init(out, in) : nothing
247-
# Wo = init(out, in)
248-
# Uo = init(out, in)
249-
# Bo = bias ? init(out, in) : nothing
250-
# Wc = init(out, in)
251-
# Uc = init(out, in)
252-
# Bc = bias ? init(out, in) : nothing
253-
# return EvolveGCNO(conv, W, init_state, in, out, Wf, Uf, Bf, Wi, Ui, Bi, Wo, Uo, Bo, Wc, Uc, Bc)
254-
# end
255-
256-
# function (egcno::EvolveGCNO)(tg::TemporalSnapshotsGNNGraph, x)
257-
# H = egcno.init_state(egcno.out, egcno.in)
258-
# C = egcno.init_state(egcno.out, egcno.in)
259-
# W = egcno.W_init
260-
# X = map(1:tg.num_snapshots) do i
261-
# F = Flux.sigmoid_fast.(egcno.Wf .* W + egcno.Uf .* H + egcno.Bf)
262-
# I = Flux.sigmoid_fast.(egcno.Wi .* W + egcno.Ui .* H + egcno.Bi)
263-
# O = Flux.sigmoid_fast.(egcno.Wo .* W + egcno.Uo .* H + egcno.Bo)
264-
# C̃ = Flux.tanh_fast.(egcno.Wc .* W + egcno.Uc .* H + egcno.Bc)
265-
# C = F .* C + I .* C̃
266-
# H = O .* tanh_fast.(C)
267-
# W = H
268-
# egcno.conv(tg.snapshots[i], x[i]; conv_weight = H)
269-
# end
270-
# return X
271-
# end
227+
function EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32)
228+
in, out = ch
229+
W = init(out, in)
230+
conv = GCNConv(ch; bias = bias, init = init)
231+
Wf = init(out, in)
232+
Uf = init(out, in)
233+
Bf = bias ? init(out, in) : nothing
234+
Wi = init(out, in)
235+
Ui = init(out, in)
236+
Bi = bias ? init(out, in) : nothing
237+
Wo = init(out, in)
238+
Uo = init(out, in)
239+
Bo = bias ? init(out, in) : nothing
240+
Wc = init(out, in)
241+
Uc = init(out, in)
242+
Bc = bias ? init(out, in) : nothing
243+
return EvolveGCNO(conv, W, init_state, in, out, Wf, Uf, Bf, Wi, Ui, Bi, Wo, Uo, Bo, Wc, Uc, Bc)
244+
end
245+
246+
function (egcno::EvolveGCNO)(tg::TemporalSnapshotsGNNGraph, x)
247+
H = egcno.init_state(egcno.out, egcno.in)
248+
C = egcno.init_state(egcno.out, egcno.in)
249+
W = egcno.W_init
250+
X = map(1:tg.num_snapshots) do i
251+
F = Flux.sigmoid_fast.(egcno.Wf .* W + egcno.Uf .* H + egcno.Bf)
252+
I = Flux.sigmoid_fast.(egcno.Wi .* W + egcno.Ui .* H + egcno.Bi)
253+
O = Flux.sigmoid_fast.(egcno.Wo .* W + egcno.Uo .* H + egcno.Bo)
254+
C̃ = Flux.tanh_fast.(egcno.Wc .* W + egcno.Uc .* H + egcno.Bc)
255+
C = F .* C + I .* C̃
256+
H = O .* tanh_fast.(C)
257+
W = H
258+
egcno.conv(tg.snapshots[i], x[i]; conv_weight = H)
259+
end
260+
return X
261+
end
272262

273-
# function Base.show(io::IO, egcno::EvolveGCNO)
274-
# print(io, "EvolveGCNO($(egcno.in) => $(egcno.out))")
275-
# end
263+
function Base.show(io::IO, egcno::EvolveGCNO)
264+
print(io, "EvolveGCNO($(egcno.in) => $(egcno.out))")
265+
end

0 commit comments

Comments
 (0)