Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions GNNLux/src/GNNLux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,11 @@ export AGNNConv,
# TransformerConv

include("layers/temporalconv.jl")
export TGCN
export A3TGCN
export TGCN,
A3TGCN,
GConvGRU,
GConvLSTM,
DCGRU

end #module

181 changes: 181 additions & 0 deletions GNNLux/src/layers/temporalconv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,184 @@ LuxCore.outputsize(l::A3TGCN) = (l.out_dims,)
function Base.show(io::IO, l::A3TGCN)
print(io, "A3TGCN($(l.in_dims) => $(l.out_dims))")
end

# Graph-convolutional GRU cell: a container layer holding six inner convolutions,
# one input-side (`conv_x_*`) and one hidden-side (`conv_h_*`) for each of the
# reset gate (`r`), update gate (`z`) and candidate hidden state (`h`).
# The field order below is also the positional constructor order.
@concrete struct GConvGRUCell <: GNNContainerLayer{(:conv_x_r, :conv_h_r, :conv_x_z, :conv_h_z, :conv_x_h, :conv_h_h)}
# Input feature size.
in_dims::Int
# Hidden/output feature size.
out_dims::Int
# Second positional argument forwarded to every inner ChebConv.
k::Int
# Reset gate convolutions (applied to the input and to the hidden state).
conv_x_r
conv_h_r
# Update gate convolutions.
conv_x_z
conv_h_z
# Candidate hidden state convolutions.
conv_x_h
conv_h_h
# Called as `init_state(out_dims, num_nodes)` when no hidden state is supplied.
init_state::Function
end

"""
    GConvGRUCell(in_dims => out_dims, k; use_bias, init_weight, init_state, init_bias)

Build a `GConvGRUCell` with six `ChebConv` layers: for each of the reset gate,
update gate and candidate hidden state there is one convolution mapping
`in_dims => out_dims` (input side) and one mapping `out_dims => out_dims`
(hidden side). `k`, `use_bias`, `init_weight` and `init_bias` are forwarded to
every `ChebConv`; `init_state` is stored on the cell.
"""
function GConvGRUCell(
        ch::Pair{Int, Int},
        k::Int;
        use_bias = true,
        init_weight = glorot_uniform,
        init_state = zeros32,
        init_bias = zeros32
)
    in_dims, out_dims = ch
    # All six convolutions share the same settings and differ only in input width.
    make_conv(n_in) = ChebConv(n_in => out_dims, k; use_bias, init_weight, init_bias)

    reset_x, reset_h = make_conv(in_dims), make_conv(out_dims)
    update_x, update_h = make_conv(in_dims), make_conv(out_dims)
    cand_x, cand_h = make_conv(in_dims), make_conv(out_dims)

    return GConvGRUCell(in_dims, out_dims, k,
        reset_x, reset_h, update_x, update_h, cand_x, cand_h, init_state)
end

# Forward pass of the graph-convolutional GRU cell.
#
# Arguments
#   - `g`: graph; `g.num_nodes` is read to size a freshly initialised hidden state.
#   - `(x, h)`: input features and previous hidden state. When `h === nothing` the
#     hidden state is created with `l.init_state(l.out_dims, g.num_nodes)`.
#   - `ps`, `st`: parameters and states with one entry per inner convolution.
#
# Returns `((h, h), st_new)`: the new hidden state (as both output and carry) and a
# NamedTuple with the updated state of each inner convolution.
function (l::GConvGRUCell)(g, (x, h), ps, st)
    if h === nothing
        h = l.init_state(l.out_dims, g.num_nodes)
    end

    # Reset gate.
    xr, st_conv_x_r = l.conv_x_r(g, x, ps.conv_x_r, st.conv_x_r)
    hr, st_conv_h_r = l.conv_h_r(g, h, ps.conv_h_r, st.conv_h_r)
    r = NNlib.sigmoid_fast.(xr .+ hr)

    # Update gate.
    xz, st_conv_x_z = l.conv_x_z(g, x, ps.conv_x_z, st.conv_x_z)
    hz, st_conv_h_z = l.conv_h_z(g, h, ps.conv_h_z, st.conv_h_z)
    z = NNlib.sigmoid_fast.(xz .+ hz)

    # Candidate hidden state. The tanh must be applied to the candidate
    # pre-activation `xh .+ hh`, not to the previous hidden state `h`; otherwise
    # the reset gate and both candidate convolutions have no effect on the output.
    xh, st_conv_x_h = l.conv_x_h(g, x, ps.conv_x_h, st.conv_x_h)
    hh, st_conv_h_h = l.conv_h_h(g, r .* h, ps.conv_h_h, st.conv_h_h)
    h_candidate = NNlib.tanh_fast.(xh .+ hh)

    # Blend candidate and previous state using the update gate.
    h = (1 .- z) .* h_candidate .+ z .* h

    st_new = (conv_x_r = st_conv_x_r, conv_h_r = st_conv_h_r,
        conv_x_z = st_conv_x_z, conv_h_z = st_conv_h_z,
        conv_x_h = st_conv_x_h, conv_h_h = st_conv_h_h)
    return (h, h), st_new
end

# Compact one-line representation: `GConvGRUCell(in_dims => out_dims)`.
function Base.show(io::IO, l::GConvGRUCell)
    print(io, "GConvGRUCell(", l.in_dims, " => ", l.out_dims, ")")
end

# Output size reported to LuxCore: a 1-tuple holding the cell's `out_dims`.
function LuxCore.outputsize(l::GConvGRUCell)
    return (l.out_dims,)
end

"""
    GConvGRU(in_dims => out_dims, k; kwargs...)

Construct a `GConvGRUCell(ch, k; kwargs...)` and wrap it in
`GNNLux.StatefulRecurrentCell`.
"""
function GConvGRU(ch::Pair{Int, Int}, k::Int; kwargs...)
    cell = GConvGRUCell(ch, k; kwargs...)
    return GNNLux.StatefulRecurrentCell(cell)
end

# Graph-convolutional LSTM cell: a container layer holding, for each of the input
# gate (`i`), forget gate (`f`), cell gate (`c`) and output gate (`o`), an
# input-side convolution (`conv_x_*`), a hidden-side convolution (`conv_h_*`) and
# a dense layer (`dense_*`) that the forward pass applies to the cell state.
# The field order below is also the positional constructor order.
@concrete struct GConvLSTMCell <: GNNContainerLayer{(:conv_x_i, :conv_h_i, :dense_i, :conv_x_f, :conv_h_f, :dense_f, :conv_x_c, :conv_h_c, :dense_c, :conv_x_o, :conv_h_o, :dense_o)}
# Input feature size.
in_dims::Int
# Hidden/output feature size.
out_dims::Int
# Second positional argument forwarded to every inner ChebConv.
k::Int
# Input gate layers.
conv_x_i
conv_h_i
dense_i
# Forget gate layers.
conv_x_f
conv_h_f
dense_f
# Cell gate layers.
conv_x_c
conv_h_c
dense_c
# Output gate layers.
conv_x_o
conv_h_o
dense_o
# Called as `init_state(out_dims, num_nodes)` for both the hidden and the cell
# state when no previous state is supplied.
init_state::Function
end

"""
    GConvLSTMCell(in_dims => out_dims, k; use_bias, init_weight, init_state, init_bias)

Build a `GConvLSTMCell`. For each of the input, forget, cell and output gates
this creates a `ChebConv(in_dims => out_dims, k)` for the input, a
`ChebConv(out_dims => out_dims, k)` for the hidden state, and a
`Dense(out_dims, 1)` layer. `use_bias`, `init_weight` and `init_bias` are
forwarded to every inner layer; `init_state` is stored on the cell.
"""
function GConvLSTMCell(
        ch::Pair{Int, Int},
        k::Int;
        use_bias = true,
        init_weight = glorot_uniform,
        init_state = zeros32,
        init_bias = zeros32
)
    in_dims, out_dims = ch
    # Every gate uses the same three-layer recipe, so build them through helpers.
    make_conv(n_in) = ChebConv(n_in => out_dims, k; use_bias, init_weight, init_bias)
    make_dense() = Dense(out_dims, 1; use_bias, init_weight, init_bias)
    make_gate() = (make_conv(in_dims), make_conv(out_dims), make_dense())

    in_x, in_h, in_d = make_gate()        # input gate
    fg_x, fg_h, fg_d = make_gate()        # forget gate
    cell_x, cell_h, cell_d = make_gate()  # cell gate
    out_x, out_h, out_d = make_gate()     # output gate

    return GConvLSTMCell(in_dims, out_dims, k,
        in_x, in_h, in_d,
        fg_x, fg_h, fg_d,
        cell_x, cell_h, cell_d,
        out_x, out_h, out_d,
        init_state)
end

# Forward pass of the graph-convolutional LSTM cell.
#
# Arguments
#   - `g`: graph; `g.num_nodes` is read to size freshly initialised states.
#   - `(x, m)`: input features and previous memory. `m` is either `nothing` (both
#     hidden and cell state are created with `l.init_state`) or a `(h, c)` pair.
#   - `ps`, `st`: parameters and states of the inner layers.
#
# Returns `((h, (h, c)), st_new)` where `st_new` holds the updated states of the
# eight convolutions.
# NOTE(review): the returned state has no `dense_*` entries even though the dense
# layers are wrapped statefully here — confirm this is intended.
function (l::GConvLSTMCell)(g, (x, m), ps, st)
    h_prev, c_prev = if m === nothing
        (l.init_state(l.out_dims, g.num_nodes), l.init_state(l.out_dims, g.num_nodes))
    else
        m
    end

    # Dense layers applied to the cell state, wrapped so they can be called directly.
    peep_i = StatefulLuxLayer{true}(l.dense_i, ps.dense_i, _getstate(st, :dense_i))
    peep_f = StatefulLuxLayer{true}(l.dense_f, ps.dense_f, _getstate(st, :dense_f))
    peep_c = StatefulLuxLayer{true}(l.dense_c, ps.dense_c, _getstate(st, :dense_c))
    peep_o = StatefulLuxLayer{true}(l.dense_o, ps.dense_o, _getstate(st, :dense_o))

    # Input gate (uses the previous cell state).
    xi, st_xi = l.conv_x_i(g, x, ps.conv_x_i, st.conv_x_i)
    hi, st_hi = l.conv_h_i(g, h_prev, ps.conv_h_i, st.conv_h_i)
    gate_in = NNlib.sigmoid_fast(xi .+ hi .+ peep_i(c_prev))

    # Forget gate (uses the previous cell state).
    xf, st_xf = l.conv_x_f(g, x, ps.conv_x_f, st.conv_x_f)
    hf, st_hf = l.conv_h_f(g, h_prev, ps.conv_h_f, st.conv_h_f)
    gate_forget = NNlib.sigmoid_fast(xf .+ hf .+ peep_f(c_prev))

    # New cell state.
    xc, st_xc = l.conv_x_c(g, x, ps.conv_x_c, st.conv_x_c)
    hc, st_hc = l.conv_h_c(g, h_prev, ps.conv_h_c, st.conv_h_c)
    c_next = gate_forget .* c_prev + gate_in .* NNlib.tanh_fast(xc .+ hc .+ peep_c(c_prev))

    # Output gate (uses the *new* cell state) and new hidden state.
    xo, st_xo = l.conv_x_o(g, x, ps.conv_x_o, st.conv_x_o)
    ho, st_ho = l.conv_h_o(g, h_prev, ps.conv_h_o, st.conv_h_o)
    gate_out = NNlib.sigmoid_fast(xo .+ ho .+ peep_o(c_next))
    h_next = gate_out .* NNlib.tanh_fast(c_next)

    st_new = (; conv_x_i = st_xi, conv_h_i = st_hi,
        conv_x_f = st_xf, conv_h_f = st_hf,
        conv_x_c = st_xc, conv_h_c = st_hc,
        conv_x_o = st_xo, conv_h_o = st_ho)
    return (h_next, (h_next, c_next)), st_new
end

# Compact one-line representation: `GConvLSTMCell(in_dims => out_dims)`.
function Base.show(io::IO, l::GConvLSTMCell)
    print(io, "GConvLSTMCell(", l.in_dims, " => ", l.out_dims, ")")
end

# Output size reported to LuxCore: a 1-tuple holding the cell's `out_dims`.
function LuxCore.outputsize(l::GConvLSTMCell)
    return (l.out_dims,)
end

"""
    GConvLSTM(in_dims => out_dims, k; kwargs...)

Construct a `GConvLSTMCell(ch, k; kwargs...)` and wrap it in
`GNNLux.StatefulRecurrentCell`.
"""
function GConvLSTM(ch::Pair{Int, Int}, k::Int; kwargs...)
    cell = GConvLSTMCell(ch, k; kwargs...)
    return GNNLux.StatefulRecurrentCell(cell)
end

# Diffusion-convolutional GRU cell: a container layer holding three inner
# convolutions that all act on the input features stacked with a hidden state.
# The field order below is also the positional constructor order.
@concrete struct DCGRUCell <: GNNContainerLayer{(:dconv_u, :dconv_r, :dconv_c)}
# Input feature size.
in_dims::Int
# Hidden/output feature size.
out_dims::Int
# Second positional argument forwarded to every inner DConv.
k::Int
# Update gate convolution (its sigmoid output weights the previous hidden state).
dconv_u
# Reset gate convolution (its sigmoid output scales the hidden state fed to `dconv_c`).
dconv_r
# Candidate state convolution (followed by tanh in the forward pass).
dconv_c
# Called as `init_state(out_dims, num_nodes)` when no hidden state is supplied.
init_state::Function
end

"""
    DCGRUCell(in_dims => out_dims, k; use_bias, init_weight, init_state, init_bias)

Build a `DCGRUCell` with three `DConv` layers (update, reset, candidate), each
mapping `(in_dims + out_dims) => out_dims`. `k`, `use_bias`, `init_weight` and
`init_bias` are forwarded to every `DConv`; `init_state` is stored on the cell.
"""
function DCGRUCell(
        ch::Pair{Int, Int},
        k::Int;
        use_bias = true,
        init_weight = glorot_uniform,
        init_state = zeros32,
        init_bias = zeros32
)
    in_dims, out_dims = ch
    # Each convolution consumes the input features stacked with a hidden state.
    stacked_dims = in_dims + out_dims
    make_dconv() = DConv(stacked_dims => out_dims, k; use_bias, init_weight, init_bias)

    update_conv = make_dconv()
    reset_conv = make_dconv()
    cand_conv = make_dconv()
    return DCGRUCell(in_dims, out_dims, k, update_conv, reset_conv, cand_conv, init_state)
end

# Forward pass of the diffusion-convolutional GRU cell.
#
# Arguments
#   - `g`: graph; `g.num_nodes` is read to size a freshly initialised hidden state.
#   - `(x, h)`: input features and previous hidden state. When `h === nothing` the
#     hidden state is created with `l.init_state(l.out_dims, g.num_nodes)`.
#   - `ps`, `st`: parameters and states with entries `dconv_u`, `dconv_r`, `dconv_c`.
#
# Returns `((h, h), st_new)`: the new hidden state (as both output and carry) and
# the updated states of the three inner convolutions.
function (l::DCGRUCell)(g, (x, h), ps, st)
    h_prev = h === nothing ? l.init_state(l.out_dims, g.num_nodes) : h

    # Update and reset gates both read the input stacked on the previous state.
    stacked = vcat(x, h_prev)
    u_pre, st_u = l.dconv_u(g, stacked, ps.dconv_u, st.dconv_u)
    update_gate = NNlib.sigmoid_fast.(u_pre)
    r_pre, st_r = l.dconv_r(g, stacked, ps.dconv_r, st.dconv_r)
    reset_gate = NNlib.sigmoid_fast.(r_pre)

    # Candidate state: the hidden state is scaled by the reset gate before stacking.
    stacked_reset = vcat(x, h_prev .* reset_gate)
    c_pre, st_c = l.dconv_c(g, stacked_reset, ps.dconv_c, st.dconv_c)
    candidate = NNlib.tanh_fast.(c_pre)

    # The update gate weights the previous state; its complement weights the candidate.
    h_next = update_gate .* h_prev + (1 .- update_gate) .* candidate
    return (h_next, h_next), (; dconv_u = st_u, dconv_r = st_r, dconv_c = st_c)
end

# Compact one-line representation: `DCGRUCell(in_dims => out_dims)`.
function Base.show(io::IO, l::DCGRUCell)
    print(io, "DCGRUCell(", l.in_dims, " => ", l.out_dims, ")")
end

# Output size reported to LuxCore: a 1-tuple holding the cell's `out_dims`.
function LuxCore.outputsize(l::DCGRUCell)
    return (l.out_dims,)
end

"""
    DCGRU(in_dims => out_dims, k; kwargs...)

Construct a `DCGRUCell(ch, k; kwargs...)` and wrap it in
`GNNLux.StatefulRecurrentCell`.
"""
function DCGRU(ch::Pair{Int, Int}, k::Int; kwargs...)
    cell = DCGRUCell(ch, k; kwargs...)
    return GNNLux.StatefulRecurrentCell(cell)
end

24 changes: 24 additions & 0 deletions GNNLux/test/layers/temporalconv_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,28 @@
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
end

# Gradient check for GConvGRU, using `g`, `x` and `rng` from the enclosing test scope.
@testset "GConvGRU" begin
    layer = GConvGRU(3 => 3, 2)
    params = LuxCore.initialparameters(rng, layer)
    states = LuxCore.initialstates(rng, layer)
    # Scalar loss: sum over the layer output (first element of the returned tuple).
    loss = (feat, p) -> begin
        out = first(layer(g, feat, p, states))
        sum(out)
    end
    skipped = [AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()]
    test_gradients(loss, x, params; atol=1.0f-2, rtol=1.0f-2, skip_backends=skipped)
end

# Gradient check for GConvLSTM, using `g`, `x` and `rng` from the enclosing test scope.
@testset "GConvLSTM" begin
    layer = GConvLSTM(3 => 3, 2)
    params = LuxCore.initialparameters(rng, layer)
    states = LuxCore.initialstates(rng, layer)
    # Scalar loss: sum over the layer output (first element of the returned tuple).
    loss = (feat, p) -> begin
        out = first(layer(g, feat, p, states))
        sum(out)
    end
    skipped = [AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()]
    test_gradients(loss, x, params; atol=1.0f-2, rtol=1.0f-2, skip_backends=skipped)
end

# Gradient check for DCGRU, using `g`, `x` and `rng` from the enclosing test scope.
@testset "DCGRU" begin
    layer = DCGRU(3 => 3, 2)
    params = LuxCore.initialparameters(rng, layer)
    states = LuxCore.initialstates(rng, layer)
    # Scalar loss: sum over the layer output (first element of the returned tuple).
    loss = (feat, p) -> begin
        out = first(layer(g, feat, p, states))
        sum(out)
    end
    skipped = [AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()]
    test_gradients(loss, x, params; atol=1.0f-2, rtol=1.0f-2, skip_backends=skipped)
end
end
Loading