@@ -421,20 +421,32 @@ function DCGRUCell(ch::Pair{Int, Int}, k::Int; use_bias = true, init_weight = gl
421421    return  DCGRUCell (in_dims, out_dims, k, dconv_u, dconv_r, dconv_c, init_state)
422422end 
423423
424- function  (l:: DCGRUCell )(g, (x, h), ps, st)
425-     if  h ===  nothing 
426-         h =  l. init_state (l. out_dims, g. num_nodes)
427-     end 
428-     h̃ =  vcat (x, h)
429-     z, st_dconv_u =  l. dconv_u (g, h̃, ps. dconv_u, st. dconv_u)
430-     z =  NNlib. sigmoid_fast .(z)
431-     r, st_dconv_r =  l. dconv_r (g, h̃, ps. dconv_r, st. dconv_r)
432-     r =  NNlib. sigmoid_fast .(r)
433-     ĥ =  vcat (x, h .*  r)
434-     c, st_dconv_c =  l. dconv_c (g, ĥ, ps. dconv_c, st. dconv_c)
435-     c =  NNlib. tanh_fast .(c)
436-     h =  z.*  h +  (1  .-  z). *  c
437-     return  (h, h), (dconv_u =  st_dconv_u, dconv_r =  st_dconv_r, dconv_c =  st_dconv_c)
424+ #  function (l::DCGRUCell)(g, (x, h), ps, st)
425+ #      if h === nothing
426+ #          h = l.init_state(l.out_dims, g.num_nodes)
427+ #      end
428+ #      h̃ = vcat(x, h)
429+ #      z, st_dconv_u = l.dconv_u(g, h̃, ps.dconv_u, st.dconv_u)
430+ #      z = NNlib.sigmoid_fast.(z)
431+ #      r, st_dconv_r = l.dconv_r(g, h̃, ps.dconv_r, st.dconv_r)
432+ #      r = NNlib.sigmoid_fast.(r)
433+ #      ĥ = vcat(x, h .* r)
434+ #      c, st_dconv_c = l.dconv_c(g, ĥ, ps.dconv_c, st.dconv_c)
435+ #      c = NNlib.tanh_fast.(c)
436+ #      h = z.* h + (1 .- z).* c
437+ #      return (h, h), (dconv_u = st_dconv_u, dconv_r = st_dconv_r, dconv_c = st_dconv_c)
438+ #  end
439+ 
440+ 
441+ function  (l:: DCGRUCell )(g, x:: AbstractMatrix , ps, st)
442+     h =  l. init_state (l. out_dims, g. num_nodes)
443+     return  l (g, (x, (h,)), ps, st)
444+ end 
445+ 
446+ function  (l:: DCGRUCell )(g, (x, (h,)):: Tuple , ps, st)
447+     m =  StatefulLuxLayer {true} (l, ps, st)
448+     h, _ =  dcgrucell_frwd (m, g, x, h)
449+     return  (h, (h,)), _getstate (m)
438450end 
439451
440452function  Base. show (io:: IO , l:: DCGRUCell )
0 commit comments