Skip to content

Commit e28f65e

Browse files
DCGruCell
1 parent d3cbab9 commit e28f65e

File tree

1 file changed

+26
-14
lines changed

1 file changed

+26
-14
lines changed

GNNLux/src/layers/temporalconv.jl

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -421,20 +421,32 @@ function DCGRUCell(ch::Pair{Int, Int}, k::Int; use_bias = true, init_weight = gl
421421
return DCGRUCell(in_dims, out_dims, k, dconv_u, dconv_r, dconv_c, init_state)
422422
end
423423

424-
function (l::DCGRUCell)(g, (x, h), ps, st)
425-
if h === nothing
426-
h = l.init_state(l.out_dims, g.num_nodes)
427-
end
428-
= vcat(x, h)
429-
z, st_dconv_u = l.dconv_u(g, h̃, ps.dconv_u, st.dconv_u)
430-
z = NNlib.sigmoid_fast.(z)
431-
r, st_dconv_r = l.dconv_r(g, h̃, ps.dconv_r, st.dconv_r)
432-
r = NNlib.sigmoid_fast.(r)
433-
= vcat(x, h .* r)
434-
c, st_dconv_c = l.dconv_c(g, ĥ, ps.dconv_c, st.dconv_c)
435-
c = NNlib.tanh_fast.(c)
436-
h = z.* h + (1 .- z).* c
437-
return (h, h), (dconv_u = st_dconv_u, dconv_r = st_dconv_r, dconv_c = st_dconv_c)
424+
# function (l::DCGRUCell)(g, (x, h), ps, st)
425+
# if h === nothing
426+
# h = l.init_state(l.out_dims, g.num_nodes)
427+
# end
428+
# h̃ = vcat(x, h)
429+
# z, st_dconv_u = l.dconv_u(g, h̃, ps.dconv_u, st.dconv_u)
430+
# z = NNlib.sigmoid_fast.(z)
431+
# r, st_dconv_r = l.dconv_r(g, h̃, ps.dconv_r, st.dconv_r)
432+
# r = NNlib.sigmoid_fast.(r)
433+
# ĥ = vcat(x, h .* r)
434+
# c, st_dconv_c = l.dconv_c(g, ĥ, ps.dconv_c, st.dconv_c)
435+
# c = NNlib.tanh_fast.(c)
436+
# h = z.* h + (1 .- z).* c
437+
# return (h, h), (dconv_u = st_dconv_u, dconv_r = st_dconv_r, dconv_c = st_dconv_c)
438+
# end
439+
440+
441+
function (l::DCGRUCell)(g, x::AbstractMatrix, ps, st)
442+
h = l.init_state(l.out_dims, g.num_nodes)
443+
return l(g, (x, (h,)), ps, st)
444+
end
445+
446+
function (l::DCGRUCell)(g, (x, (h,))::Tuple, ps, st)
447+
m = StatefulLuxLayer{true}(l, ps, st)
448+
h, _ = dcgrucell_frwd(m, g, x, h)
449+
return (h, (h,)), _getstate(m)
438450
end
439451

440452
function Base.show(io::IO, l::DCGRUCell)

0 commit comments

Comments
 (0)