Skip to content

Commit cb82352

Browse files
authored
[GNNLux] Add A3TGCN temporal layer (#485)
* Export A3TGCN * Add struct * Add test A#TGCN * Fix test
1 parent ed78e88 commit cb82352

File tree

3 files changed

+47
-2
lines changed

3 files changed

+47
-2
lines changed

GNNLux/src/GNNLux.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ export AGNNConv,
4141

4242
include("layers/temporalconv.jl")
4343
export TGCN
44+
export A3TGCN
4445

4546
end #module
4647

GNNLux/src/layers/temporalconv.jl

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,40 @@ function Base.show(io::IO, tgcn::TGCNCell)
5656
print(io, "TGCNCell($(tgcn.in_dims) => $(tgcn.out_dims))")
5757
end
5858

59-
TGCN(ch::Pair{Int, Int}; kwargs...) = GNNLux.StatefulRecurrentCell(TGCNCell(ch; kwargs...))
59+
TGCN(ch::Pair{Int, Int}; kwargs...) = GNNLux.StatefulRecurrentCell(TGCNCell(ch; kwargs...))
60+
61+
@concrete struct A3TGCN <: GNNContainerLayer{(:tgcn, :dense1, :dense2)}
62+
in_dims::Int
63+
out_dims::Int
64+
tgcn
65+
dense1
66+
dense2
67+
end
68+
69+
function A3TGCN(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false, use_edge_weight = true)
70+
in_dims, out_dims = ch
71+
tgcn = TGCN(ch; use_bias, init_weight, init_state, init_bias, add_self_loops, use_edge_weight)
72+
dense1 = Dense(out_dims, out_dims)
73+
dense2 = Dense(out_dims, out_dims)
74+
return A3TGCN(in_dims, out_dims, tgcn, dense1, dense2)
75+
end
76+
77+
function (l::A3TGCN)(g, x, ps, st)
78+
dense1 = StatefulLuxLayer{true}(l.dense1, ps.dense1, _getstate(st, :dense1))
79+
dense2 = StatefulLuxLayer{true}(l.dense2, ps.dense2, _getstate(st, :dense2))
80+
h, st = l.tgcn(g, x, ps.tgcn, st.tgcn)
81+
x = dense1(h)
82+
x = dense2(x)
83+
a = NNlib.softmax(x, dims = 3)
84+
c = sum(a .* h , dims = 3)
85+
if length(size(c)) == 3
86+
c = dropdims(c, dims = 3)
87+
end
88+
return c, st
89+
end
90+
91+
LuxCore.outputsize(l::A3TGCN) = (l.out_dims,)
92+
93+
function Base.show(io::IO, l::A3TGCN)
94+
print(io, "A3TGCN($(l.in_dims) => $(l.out_dims))")
95+
end

GNNLux/test/layers/temporalconv_test.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
using LuxTestUtils: test_gradients, AutoReverseDiff, AutoTracker, AutoForwardDiff, AutoEnzyme
33

44
rng = StableRNG(1234)
5-
g = rand_graph(10, 40, seed=1234)
5+
g = rand_graph(rng, 10, 40)
66
x = randn(rng, Float32, 3, 10)
77

88
@testset "TGCN" begin
@@ -12,4 +12,12 @@
1212
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
1313
test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
1414
end
15+
16+
@testset "A3TGCN" begin
17+
l = A3TGCN(3=>3)
18+
ps = LuxCore.initialparameters(rng, l)
19+
st = LuxCore.initialstates(rng, l)
20+
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
21+
test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
22+
end
1523
end

0 commit comments

Comments
 (0)