@@ -34,9 +34,9 @@ func (l *GRU) Forward(x, h matrix.Matrix, _ ...Opts) matrix.Matrix {
3434 Whz , Whr , Whh := WhH [0 ], WhH [1 ], WhH [2 ] // (H, H)
3535 Bz , Br , Bh := BH [0 ], BH [1 ], BH [2 ] // (1, H)
3636
37- l .z = matrix .F (matrix .Dot (x , Wxz ).Add (matrix .Dot (h , Whz )).Add (Bz ), activation .Sigmoid ) // z = sigmoid(x.Wxz + h.Whz + bz)
38- l .r = matrix .F (matrix .Dot (x , Wxr ).Add (matrix .Dot (h , Whr )).Add (Br ), activation .Sigmoid ) // r = sigmoid(x.Wxr + h.Whr + br)
39- l .hhat = matrix .F (matrix .Dot (x , Wxh ).Add (matrix .Dot (h .Mul (l .r ), Whh )).Add (Bh ), activation .Tanh ) // hhat = tanh(x.Wxh + (h * r).Whh + bh)
37+ l .z = matrix .F (matrix .MatMul (x , Wxz ).Add (matrix .MatMul (h , Whz )).Add (Bz ), activation .Sigmoid ) // z = sigmoid(x.Wxz + h.Whz + bz)
38+ l .r = matrix .F (matrix .MatMul (x , Wxr ).Add (matrix .MatMul (h , Whr )).Add (Br ), activation .Sigmoid ) // r = sigmoid(x.Wxr + h.Whr + br)
39+ l .hhat = matrix .F (matrix .MatMul (x , Wxh ).Add (matrix .MatMul (h .Mul (l .r ), Whh )).Add (Bh ), activation .Tanh ) // hhat = tanh(x.Wxh + (h * r).Whh + bh)
4040 l .x , l .hprev = x , h
4141
4242 hnext := matrix .SubC (1 , l .z ).Mul (l .hprev ).Add (l .z .Mul (l .hhat )) // (1 - z) * hprev + z * hhat
@@ -56,31 +56,31 @@ func (l *GRU) Backward(dhnext matrix.Matrix) (matrix.Matrix, matrix.Matrix) {
5656 dhprev := dhnext .Mul (matrix .SubC (1 , l .z )) // dhprev = dhnext * (1 - z)
5757
5858 // tanh
59- dt := dhhat .Mul (matrix .F (l .hhat , dTanh )) // dt = dhhat * (1 - hhat**2)
60- dbh := matrix .New (dt .SumAxis0 ()) // dbh = sum(dt, axis=0)
61- dWhh := matrix .Dot (l .r .Mul (l .hprev ).T (), dt ) // dWhh = (r * hprev).T.dt
62- dhr := matrix .Dot (dt , Whh .T ()) // dhr = dt.Whh.T
63- dWxh := matrix .Dot (l .x .T (), dt ) // dWxh = x.T.dt
64- dx := matrix .Dot (dt , Wxh .T ()) // dx = dt.Wxh.T
65- dhprev = dhprev .Add (dhr .Mul (l .r )) // dhprev = dhprev + dhr * r
59+ dt := dhhat .Mul (matrix .F (l .hhat , dTanh )) // dt = dhhat * (1 - hhat**2)
60+ dbh := matrix .New (dt .SumAxis0 ()) // dbh = sum(dt, axis=0)
61+ dWhh := matrix .MatMul (l .r .Mul (l .hprev ).T (), dt ) // dWhh = (r * hprev).T.dt
62+ dhr := matrix .MatMul (dt , Whh .T ()) // dhr = dt.Whh.T
63+ dWxh := matrix .MatMul (l .x .T (), dt ) // dWxh = x.T.dt
64+ dx := matrix .MatMul (dt , Wxh .T ()) // dx = dt.Wxh.T
65+ dhprev = dhprev .Add (dhr .Mul (l .r )) // dhprev = dhprev + dhr * r
6666
6767 // gate(z)
6868 dz := dhnext .Mul (l .hhat ).Sub (dhnext .Mul (l .hprev )) // dz = dhnext * hhat - dhnext * hprev
6969 dtz := dz .Mul (matrix .F (l .z , dSigmoid )) // dtz = dz * z * (1 - z)
7070 dbz := matrix .New (dtz .SumAxis0 ()) // dbz = sum(dtz, axis=0)
71- dWhz := matrix .Dot (l .hprev .T (), dtz ) // dWhz = hprev.T.dtz
72- dhprev = dhprev .Add (matrix .Dot (dtz , Whz .T ())) // dhprev = dhprev + dtz.Whz.T
73- dWxz := matrix .Dot (l .x .T (), dtz ) // dWxz = x.T.dtz
74- dx = dx .Add (matrix .Dot (dt , Wxz .T ())) // dx = dx + dtz.Wxz.T
71+ dWhz := matrix .MatMul (l .hprev .T (), dtz ) // dWhz = hprev.T.dtz
72+ dhprev = dhprev .Add (matrix .MatMul (dtz , Whz .T ())) // dhprev = dhprev + dtz.Whz.T
73+ dWxz := matrix .MatMul (l .x .T (), dtz ) // dWxz = x.T.dtz
74+ dx = dx .Add (matrix .MatMul (dt , Wxz .T ())) // dx = dx + dtz.Wxz.T
7575
7676 // gate(r)
77- dr := dhr .Mul (l .hprev ) // dr = dhr * hprev
78- dtr := dr .Mul (matrix .F (l .r , dSigmoid )) // dtr = dr * r * (1 - r)
79- dbr := matrix .New (dtr .SumAxis0 ()) // dbr = sum(dtr, axis=0)
80- dWhr := matrix .Dot (l .hprev .T (), dtr ) // dWhr = hprev.T.dtr
81- dhprev = dhprev .Add (matrix .Dot (dtr , Whr .T ())) // dhprev = dhprev + dtr.Whr.T
82- dWxr := matrix .Dot (l .x .T (), dtr ) // dWzr = x.T.dtr
83- dx = dx .Add (matrix .Dot (dtr , Wxr .T ())) // dx = dx + dtr.Wxr.T
77+ dr := dhr .Mul (l .hprev ) // dr = dhr * hprev
78+ dtr := dr .Mul (matrix .F (l .r , dSigmoid )) // dtr = dr * r * (1 - r)
79+ dbr := matrix .New (dtr .SumAxis0 ()) // dbr = sum(dtr, axis=0)
80+ dWhr := matrix .MatMul (l .hprev .T (), dtr ) // dWhr = hprev.T.dtr
81+ dhprev = dhprev .Add (matrix .MatMul (dtr , Whr .T ())) // dhprev = dhprev + dtr.Whr.T
82+ dWxr := matrix .MatMul (l .x .T (), dtr ) // dWzr = x.T.dtr
83+ dx = dx .Add (matrix .MatMul (dtr , Wxr .T ())) // dx = dx + dtr.Wxr.T
8484
8585 // grads
8686 l .DWx = matrix .HStack (dWxz , dWxr , dWxh )
0 commit comments