@@ -421,20 +421,32 @@ function DCGRUCell(ch::Pair{Int, Int}, k::Int; use_bias = true, init_weight = gl
421
421
return DCGRUCell (in_dims, out_dims, k, dconv_u, dconv_r, dconv_c, init_state)
422
422
end
423
423
424
- function (l:: DCGRUCell )(g, (x, h), ps, st)
425
- if h === nothing
426
- h = l. init_state (l. out_dims, g. num_nodes)
427
- end
428
- h̃ = vcat (x, h)
429
- z, st_dconv_u = l. dconv_u (g, h̃, ps. dconv_u, st. dconv_u)
430
- z = NNlib. sigmoid_fast .(z)
431
- r, st_dconv_r = l. dconv_r (g, h̃, ps. dconv_r, st. dconv_r)
432
- r = NNlib. sigmoid_fast .(r)
433
- ĥ = vcat (x, h .* r)
434
- c, st_dconv_c = l. dconv_c (g, ĥ, ps. dconv_c, st. dconv_c)
435
- c = NNlib. tanh_fast .(c)
436
- h = z.* h + (1 .- z). * c
437
- return (h, h), (dconv_u = st_dconv_u, dconv_r = st_dconv_r, dconv_c = st_dconv_c)
424
+ # function (l::DCGRUCell)(g, (x, h), ps, st)
425
+ # if h === nothing
426
+ # h = l.init_state(l.out_dims, g.num_nodes)
427
+ # end
428
+ # h̃ = vcat(x, h)
429
+ # z, st_dconv_u = l.dconv_u(g, h̃, ps.dconv_u, st.dconv_u)
430
+ # z = NNlib.sigmoid_fast.(z)
431
+ # r, st_dconv_r = l.dconv_r(g, h̃, ps.dconv_r, st.dconv_r)
432
+ # r = NNlib.sigmoid_fast.(r)
433
+ # ĥ = vcat(x, h .* r)
434
+ # c, st_dconv_c = l.dconv_c(g, ĥ, ps.dconv_c, st.dconv_c)
435
+ # c = NNlib.tanh_fast.(c)
436
+ # h = z.* h + (1 .- z).* c
437
+ # return (h, h), (dconv_u = st_dconv_u, dconv_r = st_dconv_r, dconv_c = st_dconv_c)
438
+ # end
439
+
440
+
441
+ function (l:: DCGRUCell )(g, x:: AbstractMatrix , ps, st)
442
+ h = l. init_state (l. out_dims, g. num_nodes)
443
+ return l (g, (x, (h,)), ps, st)
444
+ end
445
+
446
+ function (l:: DCGRUCell )(g, (x, (h,)):: Tuple , ps, st)
447
+ m = StatefulLuxLayer {true} (l, ps, st)
448
+ h, _ = dcgrucell_frwd (m, g, x, h)
449
+ return (h, (h,)), _getstate (m)
438
450
end
439
451
440
452
function Base. show (io:: IO , l:: DCGRUCell )
0 commit comments