Skip to content

Commit 7319f4d

Browse files
[GNNLux] TGCN temporal layer (#470)
* First draft * Fix signature * Improvement * Export TGCN * Fixes * Back to previous version * Add test * Remove GNNlib code * Fix * Fix Co-authored-by: Carlo Lucibello <[email protected]> --------- Co-authored-by: Carlo Lucibello <[email protected]>
1 parent 87f3c60 commit 7319f4d

File tree

4 files changed

+78
-2
lines changed

4 files changed

+78
-2
lines changed

GNNLux/src/GNNLux.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ export AGNNConv,
3737
SGConv
3838
# TAGConv,
3939
# TransformerConv
40-
40+
41+
include("layers/temporalconv.jl")
42+
export TGCN
4143

4244
end #module
4345

GNNLux/src/layers/temporalconv.jl

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
@concrete struct StatefulRecurrentCell <: AbstractExplicitContainerLayer{(:cell,)}
2+
cell <: Union{<:Lux.AbstractRecurrentCell, <:GNNContainerLayer}
3+
end
4+
5+
function LuxCore.initialstates(rng::AbstractRNG, r::GNNLux.StatefulRecurrentCell)
6+
return (cell=LuxCore.initialstates(rng, r.cell), carry=nothing)
7+
end
8+
9+
function (r::StatefulRecurrentCell)(g, x::AbstractMatrix, ps, st::NamedTuple)
10+
(out, carry), st = applyrecurrentcell(r.cell, g, x, ps, st.cell, st.carry)
11+
return out, (; cell=st, carry)
12+
end
13+
14+
function (r::StatefulRecurrentCell)(g, x::AbstractVector, ps, st::NamedTuple)
15+
st, carry = st.cell, st.carry
16+
for xᵢ in x
17+
(out, carry), st = applyrecurrentcell(r.cell, g, xᵢ, ps, st, carry)
18+
end
19+
return out, (; cell=st, carry)
20+
end
21+
22+
function applyrecurrentcell(l, g, x, ps, st, carry)
23+
return Lux.apply(l, g, (x, carry), ps, st)
24+
end
25+
26+
LuxCore.apply(m::GNNContainerLayer, g, x, ps, st) = m(g, x, ps, st)
27+
28+
@concrete struct TGCNCell <: GNNContainerLayer{(:conv, :gru)}
29+
in_dims::Int
30+
out_dims::Int
31+
conv
32+
gru
33+
init_state::Function
34+
end
35+
36+
function TGCNCell(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false, use_edge_weight = true)
37+
in_dims, out_dims = ch
38+
conv = GCNConv(ch, sigmoid; init_weight, init_bias, use_bias, add_self_loops, use_edge_weight, allow_fast_activation= true)
39+
gru = Lux.GRUCell(out_dims => out_dims; use_bias, init_weight = (init_weight, init_weight, init_weight), init_bias = (init_bias, init_bias, init_bias), init_state = init_state)
40+
return TGCNCell(in_dims, out_dims, conv, gru, init_state)
41+
end
42+
43+
function (l::TGCNCell)(g, (x, h), ps, st)
44+
if h === nothing
45+
h = l.init_state(l.out_dims, 1)
46+
end
47+
x̃, stconv = l.conv(g, x, ps.conv, st.conv)
48+
(h, (h,)), stgru = l.gru((x̃,(h,)), ps.gru,st.gru)
49+
return (h, h), (conv=stconv, gru=stgru)
50+
end
51+
52+
LuxCore.outputsize(l::TGCNCell) = (l.out_dims,)
53+
LuxCore.outputsize(l::GNNLux.StatefulRecurrentCell) = (l.cell.out_dims,)
54+
55+
function Base.show(io::IO, tgcn::TGCNCell)
56+
print(io, "TGCNCell($(tgcn.in_dims) => $(tgcn.out_dims))")
57+
end
58+
59+
TGCN(ch::Pair{Int, Int}; kwargs...) = GNNLux.StatefulRecurrentCell(TGCNCell(ch; kwargs...))
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
@testitem "layers/temporalconv" setup=[SharedTestSetup] begin
2+
using LuxTestUtils: test_gradients, AutoReverseDiff, AutoTracker, AutoForwardDiff, AutoEnzyme
3+
4+
rng = StableRNG(1234)
5+
g = rand_graph(10, 40, seed=1234)
6+
x = randn(rng, Float32, 3, 10)
7+
8+
@testset "TGCN" begin
9+
l = TGCN(3=>3)
10+
ps = LuxCore.initialparameters(rng, l)
11+
st = LuxCore.initialstates(rng, l)
12+
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
13+
test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
14+
end
15+
end

GNNlib/src/GNNlib.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ export agnn_conv,
6161
transformer_conv
6262

6363
include("layers/temporalconv.jl")
64-
export a3tgcn_conv
64+
export tgcn_conv
6565

6666
include("layers/pool.jl")
6767
export global_pool,

0 commit comments

Comments
 (0)