Skip to content

Commit 2313a96

Browse files
authored
Add EvolveGCNO temporal layer (#489)
* First draft * Add test * Add export `EvolveGCNO` * Improve `EvolveGCNO` * Ecport `EvolveGCNo` * Add `EvolveGCNO` * Fix * Add `EvolveGCNO` test * Fix
1 parent c896eda commit 2313a96

File tree

6 files changed

+178
-2
lines changed

6 files changed

+178
-2
lines changed

GNNLux/src/GNNLux.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ export TGCN,
4545
A3TGCN,
4646
GConvGRU,
4747
GConvLSTM,
48-
DCGRU
48+
DCGRU,
49+
EvolveGCNO
4950

5051
end #module
5152

GNNLux/src/layers/temporalconv.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,3 +274,63 @@ LuxCore.outputsize(l::DCGRUCell) = (l.out_dims,)
274274

275275
DCGRU(ch::Pair{Int, Int}, k::Int; kwargs...) = GNNLux.StatefulRecurrentCell(DCGRUCell(ch, k; kwargs...))
276276

277+
@concrete struct EvolveGCNO <: GNNLayer
278+
in_dims::Int
279+
out_dims::Int
280+
use_bias::Bool
281+
init_weight
282+
init_state::Function
283+
init_bias
284+
end
285+
286+
function EvolveGCNO(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32)
287+
in_dims, out_dims = ch
288+
return EvolveGCNO(in_dims, out_dims, use_bias, init_weight, init_state, init_bias)
289+
end
290+
291+
function LuxCore.initialparameters(rng::AbstractRNG, l::EvolveGCNO)
292+
weight = l.init_weight(rng, l.out_dims, l.in_dims)
293+
Wf = l.init_weight(rng, l.out_dims, l.in_dims)
294+
Uf = l.init_weight(rng, l.out_dims, l.in_dims)
295+
Wi = l.init_weight(rng, l.out_dims, l.in_dims)
296+
Ui = l.init_weight(rng, l.out_dims, l.in_dims)
297+
Wo = l.init_weight(rng, l.out_dims, l.in_dims)
298+
Uo = l.init_weight(rng, l.out_dims, l.in_dims)
299+
Wc = l.init_weight(rng, l.out_dims, l.in_dims)
300+
Uc = l.init_weight(rng, l.out_dims, l.in_dims)
301+
if l.use_bias
302+
bias = l.init_bias(rng, l.out_dims)
303+
Bf = l.init_bias(rng, l.out_dims, l.in_dims)
304+
Bi = l.init_bias(rng, l.out_dims, l.in_dims)
305+
Bo = l.init_bias(rng, l.out_dims, l.in_dims)
306+
Bc = l.init_bias(rng, l.out_dims, l.in_dims)
307+
return (; conv = (; weight, bias), lstm = (; Wf, Uf, Wi, Ui, Wo, Uo, Wc, Uc, Bf, Bi, Bo, Bc))
308+
else
309+
return (; conv = (; weight), lstm = (; Wf, Uf, Wi, Ui, Wo, Uo, Wc, Uc))
310+
end
311+
end
312+
313+
function LuxCore.initialstates(rng::AbstractRNG, l::EvolveGCNO)
314+
h = l.init_state(rng, l.out_dims, l.in_dims)
315+
c = l.init_state(rng, l.out_dims, l.in_dims)
316+
return (; conv = (;), lstm = (; h, c))
317+
end
318+
319+
function (l::EvolveGCNO)(tg::TemporalSnapshotsGNNGraph, x, ps::NamedTuple, st::NamedTuple)
320+
H, C = st.lstm
321+
W = ps.conv.weight
322+
m = (; ps.conv.weight, bias = _getbias(ps),
323+
add_self_loops =true, use_edge_weight=true, σ = identity)
324+
325+
X = map(1:tg.num_snapshots) do i
326+
F = NNlib.sigmoid_fast.(ps.lstm.Wf .* W .+ ps.lstm.Uf .* H .+ ps.lstm.Bf)
327+
I = NNlib.sigmoid_fast.(ps.lstm.Wi .* W .+ ps.lstm.Ui .* H .+ ps.lstm.Bi)
328+
O = NNlib.sigmoid_fast.(ps.lstm.Wo .* W .+ ps.lstm.Uo .* H .+ ps.lstm.Bo)
329+
= NNlib.tanh_fast.(ps.lstm.Wc .* W .+ ps.lstm.Uc .* H .+ ps.lstm.Bc)
330+
C = F .* C + I .*
331+
H = O .* NNlib.tanh_fast.(C)
332+
W = H
333+
GNNlib.gcn_conv(m,tg.snapshots[i], x[i], nothing, d -> 1 ./ sqrt.(d), W)
334+
end
335+
return X, (; conv = (;), lstm = (h = H, c = C))
336+
end

GNNLux/test/layers/temporalconv_test.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
g = rand_graph(rng, 10, 40)
66
x = randn(rng, Float32, 3, 10)
77

8+
tg = TemporalSnapshotsGNNGraph([g for _ in 1:5])
9+
tx = [x for _ in 1:5]
10+
811
@testset "TGCN" begin
912
l = TGCN(3=>3)
1013
ps = LuxCore.initialparameters(rng, l)
@@ -44,4 +47,12 @@
4447
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
4548
test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
4649
end
50+
51+
@testset "EvolveGCNO" begin
52+
l = EvolveGCNO(3=>3)
53+
ps = LuxCore.initialparameters(rng, l)
54+
st = LuxCore.initialstates(rng, l)
55+
loss = (tx, ps) -> sum(sum(first(l(tg, tx, ps, st))))
56+
test_gradients(loss, tx, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
57+
end
4758
end

src/GraphNeuralNetworks.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ export TGCN,
5454
A3TGCN,
5555
GConvLSTM,
5656
GConvGRU,
57-
DCGRU
57+
DCGRU,
58+
EvolveGCNO
5859

5960
include("layers/pool.jl")
6061
export GlobalPool,

src/layers/temporalconv.jl

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,103 @@ Flux.Recur(dcgru::DCGRUCell) = Flux.Recur(dcgru, dcgru.state0)
484484
_applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph, x) = l(g, x)
485485
_applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph) = l(g)
486486

487+
"""
488+
EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32)
489+
490+
Evolving Graph Convolutional Network (EvolveGCNO) layer from the paper [EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs](https://arxiv.org/pdf/1902.10191).
491+
492+
Perfoms a Graph Convolutional layer with parameters derived from a Long Short-Term Memory (LSTM) layer across the snapshots of the temporal graph.
493+
494+
495+
# Arguments
496+
497+
- `in`: Number of input features.
498+
- `out`: Number of output features.
499+
- `bias`: Add learnable bias. Default `true`.
500+
- `init`: Weights' initializer. Default `glorot_uniform`.
501+
- `init_state`: Initial state of the hidden stat of the LSTM layer. Default `zeros32`.
502+
503+
# Examples
504+
505+
```jldoctest
506+
julia> tg = TemporalSnapshotsGNNGraph([rand_graph(10,20; ndata = rand(4,10)), rand_graph(10,14; ndata = rand(4,10)), rand_graph(10,22; ndata = rand(4,10))])
507+
TemporalSnapshotsGNNGraph:
508+
num_nodes: [10, 10, 10]
509+
num_edges: [20, 14, 22]
510+
num_snapshots: 3
511+
512+
julia> ev = EvolveGCNO(4 => 5)
513+
EvolveGCNO(4 => 5)
514+
515+
julia> size(ev(tg, tg.ndata.x))
516+
(3,)
517+
518+
julia> size(ev(tg, tg.ndata.x)[1])
519+
(5, 10)
520+
```
521+
"""
522+
struct EvolveGCNO
523+
conv
524+
W_init
525+
init_state
526+
in::Int
527+
out::Int
528+
Wf
529+
Uf
530+
Bf
531+
Wi
532+
Ui
533+
Bi
534+
Wo
535+
Uo
536+
Bo
537+
Wc
538+
Uc
539+
Bc
540+
end
541+
542+
Flux.@functor EvolveGCNO
543+
544+
function EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32)
545+
in, out = ch
546+
W = init(out, in)
547+
conv = GCNConv(ch; bias = bias, init = init)
548+
Wf = init(out, in)
549+
Uf = init(out, in)
550+
Bf = bias ? init(out, in) : nothing
551+
Wi = init(out, in)
552+
Ui = init(out, in)
553+
Bi = bias ? init(out, in) : nothing
554+
Wo = init(out, in)
555+
Uo = init(out, in)
556+
Bo = bias ? init(out, in) : nothing
557+
Wc = init(out, in)
558+
Uc = init(out, in)
559+
Bc = bias ? init(out, in) : nothing
560+
return EvolveGCNO(conv, W, init_state, in, out, Wf, Uf, Bf, Wi, Ui, Bi, Wo, Uo, Bo, Wc, Uc, Bc)
561+
end
562+
563+
function (egcno::EvolveGCNO)(tg::TemporalSnapshotsGNNGraph, x)
564+
H = egcno.init_state(egcno.out, egcno.in)
565+
C = egcno.init_state(egcno.out, egcno.in)
566+
W = egcno.W_init
567+
X = map(1:tg.num_snapshots) do i
568+
F = Flux.sigmoid_fast.(egcno.Wf .* W + egcno.Uf .* H + egcno.Bf)
569+
I = Flux.sigmoid_fast.(egcno.Wi .* W + egcno.Ui .* H + egcno.Bi)
570+
O = Flux.sigmoid_fast.(egcno.Wo .* W + egcno.Uo .* H + egcno.Bo)
571+
= Flux.tanh_fast.(egcno.Wc .* W + egcno.Uc .* H + egcno.Bc)
572+
C = F .* C + I .*
573+
H = O .* tanh_fast.(C)
574+
W = H
575+
egcno.conv(tg.snapshots[i], x[i]; conv_weight = H)
576+
end
577+
return X
578+
end
579+
580+
function Base.show(io::IO, egcno::EvolveGCNO)
581+
print(io, "EvolveGCNO($(egcno.in) => $(egcno.out))")
582+
end
583+
487584
function (l::GINConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
488585
return l.(tg.snapshots, x)
489586
end

test/layers/temporalconv.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,12 @@ end
6969
@test model(g1) isa GNNGraph
7070
end
7171

72+
@testset "EvolveGCNO" begin
73+
evolvegcno = EvolveGCNO(in_channel => out_channel)
74+
@test length(Flux.gradient(x -> sum(sum(evolvegcno(tg, x))), tg.ndata.x)[1]) == S
75+
@test size(evolvegcno(tg, tg.ndata.x)[1]) == (out_channel, N)
76+
end
77+
7278
@testset "GINConv" begin
7379
ginconv = GINConv(Dense(in_channel => out_channel),0.3)
7480
@test length(ginconv(tg, tg.ndata.x)) == S

0 commit comments

Comments
 (0)