Skip to content

Commit bd5e2f2

Browse files
authored
[GNNLux] Add GConvLSTM, GConvGRU and DCGRU temporal layers (#487)
* Add exports * Add GConvGRU, GConvLSTM and DCGRU * Add GConvGRU, GConvLSTM and DCGRU tests
1 parent 82a7450 commit bd5e2f2

File tree

3 files changed

+210
-2
lines changed

3 files changed

+210
-2
lines changed

GNNLux/src/GNNLux.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,11 @@ export AGNNConv,
4040
# TransformerConv
4141

4242
include("layers/temporalconv.jl")
43-
export TGCN
44-
export A3TGCN
43+
export TGCN,
44+
A3TGCN,
45+
GConvGRU,
46+
GConvLSTM,
47+
DCGRU
4548

4649
end #module
4750

GNNLux/src/layers/temporalconv.jl

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,184 @@ LuxCore.outputsize(l::A3TGCN) = (l.out_dims,)
9393
function Base.show(io::IO, l::A3TGCN)
9494
print(io, "A3TGCN($(l.in_dims) => $(l.out_dims))")
9595
end
96+
97+
@concrete struct GConvGRUCell <: GNNContainerLayer{(:conv_x_r, :conv_h_r, :conv_x_z, :conv_h_z, :conv_x_h, :conv_h_h)}
98+
in_dims::Int
99+
out_dims::Int
100+
k::Int
101+
conv_x_r
102+
conv_h_r
103+
conv_x_z
104+
conv_h_z
105+
conv_x_h
106+
conv_h_h
107+
init_state::Function
108+
end
109+
110+
function GConvGRUCell(ch::Pair{Int, Int}, k::Int; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32)
111+
in_dims, out_dims = ch
112+
#reset gate
113+
conv_x_r = ChebConv(in_dims => out_dims, k; use_bias, init_weight, init_bias)
114+
conv_h_r = ChebConv(out_dims => out_dims, k; use_bias, init_weight, init_bias)
115+
#update gate
116+
conv_x_z = ChebConv(in_dims => out_dims, k; use_bias, init_weight, init_bias)
117+
conv_h_z = ChebConv(out_dims => out_dims, k; use_bias, init_weight, init_bias)
118+
#hidden state
119+
conv_x_h = ChebConv(in_dims => out_dims, k; use_bias, init_weight, init_bias)
120+
conv_h_h = ChebConv(out_dims => out_dims, k; use_bias, init_weight, init_bias)
121+
return GConvGRUCell(in_dims, out_dims, k, conv_x_r, conv_h_r, conv_x_z, conv_h_z, conv_x_h, conv_h_h, init_state)
122+
end
123+
124+
function (l::GConvGRUCell)(g, (x, h), ps, st)
125+
if h === nothing
126+
h = l.init_state(l.out_dims, g.num_nodes)
127+
end
128+
xr, st_conv_xr = l.conv_x_r(g, x, ps.conv_x_r, st.conv_x_r)
129+
hr, st_conv_hr = l.conv_h_r(g, h, ps.conv_h_r, st.conv_h_r)
130+
r = xr .+ hr
131+
r = NNlib.sigmoid_fast(r)
132+
xz, st_conv_x_z = l.conv_x_z(g, x, ps.conv_x_z, st.conv_x_z)
133+
hz, st_conv_h_z = l.conv_h_z(g, h, ps.conv_h_z, st.conv_h_z)
134+
z = xz .+ hz
135+
z = NNlib.sigmoid_fast(z)
136+
xh, st_conv_x_h = l.conv_x_h(g, x, ps.conv_x_h, st.conv_x_h)
137+
hh, st_conv_h_h = l.conv_h_h(g, r .* h, ps.conv_h_h, st.conv_h_h)
138+
= xh .+ hh
139+
= NNlib.tanh_fast(h)
140+
h = (1 .- z).*+ z.* h
141+
return (h, h), (conv_x_r = st_conv_xr, conv_h_r = st_conv_hr, conv_x_z = st_conv_x_z, conv_h_z = st_conv_h_z, conv_x_h = st_conv_x_h, conv_h_h = st_conv_h_h)
142+
end
143+
144+
function Base.show(io::IO, l::GConvGRUCell)
145+
print(io, "GConvGRUCell($(l.in_dims) => $(l.out_dims))")
146+
end
147+
148+
LuxCore.outputsize(l::GConvGRUCell) = (l.out_dims,)
149+
150+
GConvGRU(ch::Pair{Int, Int}, k::Int; kwargs...) = GNNLux.StatefulRecurrentCell(GConvGRUCell(ch, k; kwargs...))
151+
152+
@concrete struct GConvLSTMCell <: GNNContainerLayer{(:conv_x_i, :conv_h_i, :dense_i, :conv_x_f, :conv_h_f, :dense_f, :conv_x_c, :conv_h_c, :dense_c, :conv_x_o, :conv_h_o, :dense_o)}
153+
in_dims::Int
154+
out_dims::Int
155+
k::Int
156+
conv_x_i
157+
conv_h_i
158+
dense_i
159+
conv_x_f
160+
conv_h_f
161+
dense_f
162+
conv_x_c
163+
conv_h_c
164+
dense_c
165+
conv_x_o
166+
conv_h_o
167+
dense_o
168+
init_state::Function
169+
end
170+
171+
function GConvLSTMCell(ch::Pair{Int, Int}, k::Int; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32)
172+
in_dims, out_dims = ch
173+
#input gate
174+
conv_x_i = ChebConv(in_dims => out_dims, k; use_bias, init_weight, init_bias)
175+
conv_h_i = ChebConv(out_dims => out_dims, k; use_bias, init_weight, init_bias)
176+
dense_i = Dense(out_dims, 1; use_bias, init_weight, init_bias)
177+
#forget gate
178+
conv_x_f = ChebConv(in_dims => out_dims, k; use_bias, init_weight, init_bias)
179+
conv_h_f = ChebConv(out_dims => out_dims, k; use_bias, init_weight, init_bias)
180+
dense_f = Dense(out_dims, 1; use_bias, init_weight, init_bias)
181+
#cell gate
182+
conv_x_c = ChebConv(in_dims => out_dims, k; use_bias, init_weight, init_bias)
183+
conv_h_c = ChebConv(out_dims => out_dims, k; use_bias, init_weight, init_bias)
184+
dense_c = Dense(out_dims, 1; use_bias, init_weight, init_bias)
185+
#output gate
186+
conv_x_o = ChebConv(in_dims => out_dims, k; use_bias, init_weight, init_bias)
187+
conv_h_o = ChebConv(out_dims => out_dims, k; use_bias, init_weight, init_bias)
188+
dense_o = Dense(out_dims, 1; use_bias, init_weight, init_bias)
189+
return GConvLSTMCell(in_dims, out_dims, k, conv_x_i, conv_h_i, dense_i, conv_x_f, conv_h_f, dense_f, conv_x_c, conv_h_c, dense_c, conv_x_o, conv_h_o, dense_o, init_state)
190+
end
191+
192+
function (l::GConvLSTMCell)(g, (x, m), ps, st)
193+
if m === nothing
194+
h = l.init_state(l.out_dims, g.num_nodes)
195+
c = l.init_state(l.out_dims, g.num_nodes)
196+
else
197+
h, c = m
198+
end
199+
200+
dense_i = StatefulLuxLayer{true}(l.dense_i, ps.dense_i, _getstate(st, :dense_i))
201+
dense_f = StatefulLuxLayer{true}(l.dense_f, ps.dense_f, _getstate(st, :dense_f))
202+
dense_c = StatefulLuxLayer{true}(l.dense_c, ps.dense_c, _getstate(st, :dense_c))
203+
dense_o = StatefulLuxLayer{true}(l.dense_o, ps.dense_o, _getstate(st, :dense_o))
204+
205+
xi, st_conv_x_i = l.conv_x_i(g, x, ps.conv_x_i, st.conv_x_i)
206+
hi, st_conv_h_i = l.conv_h_i(g, h, ps.conv_h_i, st.conv_h_i)
207+
i = xi .+ hi .+ dense_i(c)
208+
i = NNlib.sigmoid_fast(i)
209+
210+
xf, st_conv_x_f = l.conv_x_f(g, x, ps.conv_x_f, st.conv_x_f)
211+
hf, st_conv_h_f = l.conv_h_f(g, h, ps.conv_h_f, st.conv_h_f)
212+
f = xf .+ hf .+ dense_f(c)
213+
f = NNlib.sigmoid_fast(f)
214+
215+
xc, st_conv_x_c = l.conv_x_c(g, x, ps.conv_x_c, st.conv_x_c)
216+
hc, st_conv_h_c = l.conv_h_c(g, h, ps.conv_h_c, st.conv_h_c)
217+
c = f .* c + i.* NNlib.tanh_fast(xc .+ hc .+ dense_c(c))
218+
219+
xo, st_conv_x_o = l.conv_x_o(g, x, ps.conv_x_o, st.conv_x_o)
220+
ho, st_conv_h_o = l.conv_h_o(g, h, ps.conv_h_o, st.conv_h_o)
221+
o = xo .+ ho .+ dense_o(c)
222+
o = NNlib.sigmoid_fast(o)
223+
h = o.* NNlib.tanh_fast(c)
224+
return (h, (h, c)), (conv_x_i = st_conv_x_i, conv_h_i = st_conv_h_i, conv_x_f = st_conv_x_f, conv_h_f = st_conv_h_f, conv_x_c = st_conv_x_c, conv_h_c = st_conv_h_c, conv_x_o = st_conv_x_o, conv_h_o = st_conv_h_o)
225+
end
226+
227+
function Base.show(io::IO, l::GConvLSTMCell)
228+
print(io, "GConvLSTMCell($(l.in_dims) => $(l.out_dims))")
229+
end
230+
231+
LuxCore.outputsize(l::GConvLSTMCell) = (l.out_dims,)
232+
233+
GConvLSTM(ch::Pair{Int, Int}, k::Int; kwargs...) = GNNLux.StatefulRecurrentCell(GConvLSTMCell(ch, k; kwargs...))
234+
235+
@concrete struct DCGRUCell <: GNNContainerLayer{(:dconv_u, :dconv_r, :dconv_c)}
236+
in_dims::Int
237+
out_dims::Int
238+
k::Int
239+
dconv_u
240+
dconv_r
241+
dconv_c
242+
init_state::Function
243+
end
244+
245+
function DCGRUCell(ch::Pair{Int, Int}, k::Int; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32)
246+
in_dims, out_dims = ch
247+
dconv_u = DConv((in_dims + out_dims) => out_dims, k; use_bias = use_bias, init_weight = init_weight, init_bias = init_bias)
248+
dconv_r = DConv((in_dims + out_dims) => out_dims, k; use_bias = use_bias, init_weight = init_weight, init_bias = init_bias)
249+
dconv_c = DConv((in_dims + out_dims) => out_dims, k; use_bias = use_bias, init_weight = init_weight, init_bias = init_bias)
250+
return DCGRUCell(in_dims, out_dims, k, dconv_u, dconv_r, dconv_c, init_state)
251+
end
252+
253+
function (l::DCGRUCell)(g, (x, h), ps, st)
254+
if h === nothing
255+
h = l.init_state(l.out_dims, g.num_nodes)
256+
end
257+
= vcat(x, h)
258+
z, st_dconv_u = l.dconv_u(g, h̃, ps.dconv_u, st.dconv_u)
259+
z = NNlib.sigmoid_fast.(z)
260+
r, st_dconv_r = l.dconv_r(g, h̃, ps.dconv_r, st.dconv_r)
261+
r = NNlib.sigmoid_fast.(r)
262+
= vcat(x, h .* r)
263+
c, st_dconv_c = l.dconv_c(g, ĥ, ps.dconv_c, st.dconv_c)
264+
c = NNlib.tanh_fast.(c)
265+
h = z.* h + (1 .- z).* c
266+
return (h, h), (dconv_u = st_dconv_u, dconv_r = st_dconv_r, dconv_c = st_dconv_c)
267+
end
268+
269+
function Base.show(io::IO, l::DCGRUCell)
270+
print(io, "DCGRUCell($(l.in_dims) => $(l.out_dims))")
271+
end
272+
273+
LuxCore.outputsize(l::DCGRUCell) = (l.out_dims,)
274+
275+
DCGRU(ch::Pair{Int, Int}, k::Int; kwargs...) = GNNLux.StatefulRecurrentCell(DCGRUCell(ch, k; kwargs...))
276+

GNNLux/test/layers/temporalconv_test.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,28 @@
2020
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
2121
test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
2222
end
23+
24+
@testset "GConvGRU" begin
25+
l = GConvGRU(3=>3, 2)
26+
ps = LuxCore.initialparameters(rng, l)
27+
st = LuxCore.initialstates(rng, l)
28+
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
29+
test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
30+
end
31+
32+
@testset "GConvLSTM" begin
33+
l = GConvLSTM(3=>3, 2)
34+
ps = LuxCore.initialparameters(rng, l)
35+
st = LuxCore.initialstates(rng, l)
36+
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
37+
test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
38+
end
39+
40+
@testset "DCGRU" begin
41+
l = DCGRU(3=>3, 2)
42+
ps = LuxCore.initialparameters(rng, l)
43+
st = LuxCore.initialstates(rng, l)
44+
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
45+
test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
46+
end
2347
end

0 commit comments

Comments
 (0)