Skip to content

Commit 3617034

Browse files
committed
Rename dot -> matmul
1 parent 1e04889 commit 3617034

File tree

13 files changed

+106
-105
lines changed

13 files changed

+106
-105
lines changed

layer/affine.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@ func (l *Affine) String() string {
2424

2525
func (l *Affine) Forward(x, _ matrix.Matrix, _ ...Opts) matrix.Matrix {
2626
l.x = x
27-
return matrix.Dot(l.x, l.W).Add(l.B) // x.W + B
27+
return matrix.MatMul(l.x, l.W).Add(l.B) // x.W + B
2828
}
2929

3030
func (l *Affine) Backward(dout matrix.Matrix) (matrix.Matrix, matrix.Matrix) {
31-
dx := matrix.Dot(dout, l.W.T())
32-
l.DW = matrix.Dot(l.x.T(), dout)
31+
dx := matrix.MatMul(dout, l.W.T())
32+
l.DW = matrix.MatMul(l.x.T(), dout)
3333
l.DB = matrix.New(dout.SumAxis0()) // Adjusting the shape
3434
return dx, nil
3535
}

layer/dot.go

Lines changed: 0 additions & 34 deletions
This file was deleted.

layer/gru.go

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ func (l *GRU) Forward(x, h matrix.Matrix, _ ...Opts) matrix.Matrix {
3434
Whz, Whr, Whh := WhH[0], WhH[1], WhH[2] // (H, H)
3535
Bz, Br, Bh := BH[0], BH[1], BH[2] // (1, H)
3636

37-
l.z = matrix.F(matrix.Dot(x, Wxz).Add(matrix.Dot(h, Whz)).Add(Bz), activation.Sigmoid) // z = sigmoid(x.Wxz + h.Whz + bz)
38-
l.r = matrix.F(matrix.Dot(x, Wxr).Add(matrix.Dot(h, Whr)).Add(Br), activation.Sigmoid) // r = sigmoid(x.Wxr + h.Whr + br)
39-
l.hhat = matrix.F(matrix.Dot(x, Wxh).Add(matrix.Dot(h.Mul(l.r), Whh)).Add(Bh), activation.Tanh) // hhat = tanh(x.Wxh + (h * r).Whh + bh)
37+
l.z = matrix.F(matrix.MatMul(x, Wxz).Add(matrix.MatMul(h, Whz)).Add(Bz), activation.Sigmoid) // z = sigmoid(x.Wxz + h.Whz + bz)
38+
l.r = matrix.F(matrix.MatMul(x, Wxr).Add(matrix.MatMul(h, Whr)).Add(Br), activation.Sigmoid) // r = sigmoid(x.Wxr + h.Whr + br)
39+
l.hhat = matrix.F(matrix.MatMul(x, Wxh).Add(matrix.MatMul(h.Mul(l.r), Whh)).Add(Bh), activation.Tanh) // hhat = tanh(x.Wxh + (h * r).Whh + bh)
4040
l.x, l.hprev = x, h
4141

4242
hnext := matrix.SubC(1, l.z).Mul(l.hprev).Add(l.z.Mul(l.hhat)) // (1 - z) * hprev + z * hhat
@@ -56,31 +56,31 @@ func (l *GRU) Backward(dhnext matrix.Matrix) (matrix.Matrix, matrix.Matrix) {
5656
dhprev := dhnext.Mul(matrix.SubC(1, l.z)) // dhprev = dhnext * (1 - z)
5757

5858
// tanh
59-
dt := dhhat.Mul(matrix.F(l.hhat, dTanh)) // dt = dhhat * (1 - hhat**2)
60-
dbh := matrix.New(dt.SumAxis0()) // dbh = sum(dt, axis=0)
61-
dWhh := matrix.Dot(l.r.Mul(l.hprev).T(), dt) // dWhh = (r * hprev).T.dt
62-
dhr := matrix.Dot(dt, Whh.T()) // dhr = dt.Whh.T
63-
dWxh := matrix.Dot(l.x.T(), dt) // dWxh = x.T.dt
64-
dx := matrix.Dot(dt, Wxh.T()) // dx = dt.Wxh.T
65-
dhprev = dhprev.Add(dhr.Mul(l.r)) // dhprev = dhprev + dhr * r
59+
dt := dhhat.Mul(matrix.F(l.hhat, dTanh)) // dt = dhhat * (1 - hhat**2)
60+
dbh := matrix.New(dt.SumAxis0()) // dbh = sum(dt, axis=0)
61+
dWhh := matrix.MatMul(l.r.Mul(l.hprev).T(), dt) // dWhh = (r * hprev).T.dt
62+
dhr := matrix.MatMul(dt, Whh.T()) // dhr = dt.Whh.T
63+
dWxh := matrix.MatMul(l.x.T(), dt) // dWxh = x.T.dt
64+
dx := matrix.MatMul(dt, Wxh.T()) // dx = dt.Wxh.T
65+
dhprev = dhprev.Add(dhr.Mul(l.r)) // dhprev = dhprev + dhr * r
6666

6767
// gate(z)
6868
dz := dhnext.Mul(l.hhat).Sub(dhnext.Mul(l.hprev)) // dz = dhnext * hhat - dhnext * hprev
6969
dtz := dz.Mul(matrix.F(l.z, dSigmoid)) // dtz = dz * z * (1 - z)
7070
dbz := matrix.New(dtz.SumAxis0()) // dbz = sum(dtz, axis=0)
71-
dWhz := matrix.Dot(l.hprev.T(), dtz) // dWhz = hprev.T.dtz
72-
dhprev = dhprev.Add(matrix.Dot(dtz, Whz.T())) // dhprev = dhprev + dtz.Whz.T
73-
dWxz := matrix.Dot(l.x.T(), dtz) // dWxz = x.T.dtz
74-
dx = dx.Add(matrix.Dot(dt, Wxz.T())) // dx = dx + dtz.Wxz.T
71+
dWhz := matrix.MatMul(l.hprev.T(), dtz) // dWhz = hprev.T.dtz
72+
dhprev = dhprev.Add(matrix.MatMul(dtz, Whz.T())) // dhprev = dhprev + dtz.Whz.T
73+
dWxz := matrix.MatMul(l.x.T(), dtz) // dWxz = x.T.dtz
74+
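dx = dx.Add(matrix.MatMul(dt, Wxz.T())) // dx = dx + dtz.Wxz.T (NOTE(review): code passes dt, not dtz, to MatMul; the z-gate path should use dtz — confirm and fix)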
7575

7676
// gate(r)
77-
dr := dhr.Mul(l.hprev) // dr = dhr * hprev
78-
dtr := dr.Mul(matrix.F(l.r, dSigmoid)) // dtr = dr * r * (1 - r)
79-
dbr := matrix.New(dtr.SumAxis0()) // dbr = sum(dtr, axis=0)
80-
dWhr := matrix.Dot(l.hprev.T(), dtr) // dWhr = hprev.T.dtr
81-
dhprev = dhprev.Add(matrix.Dot(dtr, Whr.T())) // dhprev = dhprev + dtr.Whr.T
82-
dWxr := matrix.Dot(l.x.T(), dtr) // dWxr = x.T.dtr
83-
dx = dx.Add(matrix.Dot(dtr, Wxr.T())) // dx = dx + dtr.Wxr.T
77+
dr := dhr.Mul(l.hprev) // dr = dhr * hprev
78+
dtr := dr.Mul(matrix.F(l.r, dSigmoid)) // dtr = dr * r * (1 - r)
79+
dbr := matrix.New(dtr.SumAxis0()) // dbr = sum(dtr, axis=0)
80+
dWhr := matrix.MatMul(l.hprev.T(), dtr) // dWhr = hprev.T.dtr
81+
dhprev = dhprev.Add(matrix.MatMul(dtr, Whr.T())) // dhprev = dhprev + dtr.Whr.T
82+
dWxr := matrix.MatMul(l.x.T(), dtr) // dWxr = x.T.dtr
83+
dx = dx.Add(matrix.MatMul(dtr, Wxr.T())) // dx = dx + dtr.Wxr.T
8484

8585
// grads
8686
l.DWx = matrix.HStack(dWxz, dWxr, dWxh)

layer/lstm.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ func (l *LSTM) String() string {
2626
}
2727

2828
func (l *LSTM) Forward(x, h, c matrix.Matrix, _ ...Opts) (matrix.Matrix, matrix.Matrix) {
29-
A := matrix.Dot(x, l.Wx).Add(matrix.Dot(h, l.Wh)).Add(l.B) // (N, 4H) = x(N, D).Wx(D, 4H) + h(N, H).Wh(H, 4H) + b(1, 4H)
30-
AH := matrix.Split(A, len(h[0])) // (4, N, H)
29+
A := matrix.MatMul(x, l.Wx).Add(matrix.MatMul(h, l.Wh)).Add(l.B) // (N, 4H) = x(N, D).Wx(D, 4H) + h(N, H).Wh(H, 4H) + b(1, 4H)
30+
AH := matrix.Split(A, len(h[0])) // (4, N, H)
3131

3232
f := matrix.F(AH[0], activation.Sigmoid) // (N, H)
3333
g := matrix.F(AH[1], activation.Tanh) // (N, H)
@@ -64,14 +64,14 @@ func (l *LSTM) Backward(dhNext, dcNext matrix.Matrix) (matrix.Matrix, matrix.Mat
6464
dA := matrix.HStack(df, dg, di, do) // (N, 4H)
6565

6666
// grads
67-
l.DWx = matrix.Dot(l.x.T(), dA)
68-
l.DWh = matrix.Dot(l.h.T(), dA)
67+
l.DWx = matrix.MatMul(l.x.T(), dA)
68+
l.DWh = matrix.MatMul(l.h.T(), dA)
6969
l.DB = matrix.New(dA.SumAxis0())
7070

7171
// prev
72-
dx := matrix.Dot(dA, l.Wx.T()) // (N, D)
73-
dhPrev := matrix.Dot(dA, l.Wh.T()) // (N, H)
74-
dcPrev := ds.Mul(l.f) // (N, H)
72+
dx := matrix.MatMul(dA, l.Wx.T()) // (N, D)
73+
dhPrev := matrix.MatMul(dA, l.Wh.T()) // (N, H)
74+
dcPrev := ds.Mul(l.f) // (N, H)
7575

7676
return dx, dhPrev, dcPrev
7777
}

layer/matmul.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
package layer
2+
3+
import (
4+
"fmt"
5+
6+
"github.com/itsubaki/neu/math/matrix"
7+
)
8+
9+
// MatMul is a layer that performs a matrix product.
10+
type MatMul struct {
11+
W matrix.Matrix // params
12+
DW matrix.Matrix // grads
13+
x matrix.Matrix
14+
}
15+
16+
func (l *MatMul) Params() []matrix.Matrix { return []matrix.Matrix{l.W} }
17+
func (l *MatMul) Grads() []matrix.Matrix { return []matrix.Matrix{l.DW} }
18+
func (l *MatMul) SetParams(p ...matrix.Matrix) { l.W = p[0] }
19+
func (l *MatMul) String() string {
20+
a, b := l.W.Dim()
21+
return fmt.Sprintf("%T: W(%v, %v): %v", l, a, b, a*b)
22+
}
23+
24+
func (l *MatMul) Forward(x, _ matrix.Matrix, _ ...Opts) matrix.Matrix {
25+
l.x = x
26+
return matrix.MatMul(l.x, l.W)
27+
}
28+
29+
func (l *MatMul) Backward(dout matrix.Matrix) (matrix.Matrix, matrix.Matrix) {
30+
dx := matrix.MatMul(dout, l.W.T())
31+
dW := matrix.MatMul(l.x.T(), dout)
32+
l.DW = dW
33+
return dx, nil
34+
}
Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@ import (
77
"github.com/itsubaki/neu/math/matrix"
88
)
99

10-
func ExampleDot() {
10+
func ExampleMatMul() {
1111
W := matrix.New(
1212
[]float64{5, 6},
1313
[]float64{7, 8},
1414
)
15-
dot := &layer.Dot{W: W}
16-
fmt.Println(dot)
15+
matmul := &layer.MatMul{W: W}
16+
fmt.Println(matmul)
1717
fmt.Println()
1818

1919
// forward
@@ -22,20 +22,20 @@ func ExampleDot() {
2222
[]float64{3, 4},
2323
)
2424

25-
for _, r := range dot.Forward(A, nil) {
25+
for _, r := range matmul.Forward(A, nil) {
2626
fmt.Println(r)
2727
}
2828
fmt.Println()
2929

3030
// backward
31-
dx, _ := dot.Backward(matrix.New([]float64{1, 0}, []float64{0, 1}))
31+
dx, _ := matmul.Backward(matrix.New([]float64{1, 0}, []float64{0, 1}))
3232
for _, r := range dx {
3333
fmt.Println(r)
3434
}
3535
fmt.Println()
3636

3737
// Output:
38-
// *layer.Dot: W(2, 2): 4
38+
// *layer.MatMul: W(2, 2): 4
3939
//
4040
// [19 22]
4141
// [43 50]
@@ -46,12 +46,12 @@ func ExampleDot() {
4646

4747
}
4848

49-
func ExampleDot_Params() {
50-
dot := &layer.Dot{}
51-
dot.SetParams(make([]matrix.Matrix, 1)...)
49+
func ExampleMatMul_Params() {
50+
matmul := &layer.MatMul{}
51+
matmul.SetParams(make([]matrix.Matrix, 1)...)
5252

53-
fmt.Println(dot.Params())
54-
fmt.Println(dot.Grads())
53+
fmt.Println(matmul.Params())
54+
fmt.Println(matmul.Grads())
5555

5656
// Output:
5757
// [[]]

layer/rnn.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ func (l *RNN) String() string {
2424
}
2525

2626
func (l *RNN) Forward(x, h matrix.Matrix, _ ...Opts) matrix.Matrix {
27-
t := matrix.Dot(h, l.Wh).Add(matrix.Dot(x, l.Wx)).Add(l.B) // h(N, H).Wh(H, H) + x(N, D).Wx(D, H) + b(1, H)
27+
t := matrix.MatMul(h, l.Wh).Add(matrix.MatMul(x, l.Wx)).Add(l.B) // h(N, H).Wh(H, H) + x(N, D).Wx(D, H) + b(1, H)
2828
hNext := matrix.F(t, activation.Tanh)
2929

3030
l.x, l.h, l.hNext = x, h, hNext // cache
@@ -33,12 +33,12 @@ func (l *RNN) Forward(x, h matrix.Matrix, _ ...Opts) matrix.Matrix {
3333

3434
func (l *RNN) Backward(dhNext matrix.Matrix) (matrix.Matrix, matrix.Matrix) {
3535
dt := dhNext.Mul(matrix.F(l.hNext, dTanh)) // dt = dhNext * (1 - hNext**2)
36-
dx := matrix.Dot(dt, l.Wx.T()) // dot(dt(N, H), Wx.T(H, D)) -> dx(N, D)
37-
dh := matrix.Dot(dt, l.Wh.T()) // dot(dt(N, H), Wh.T(H, H)) -> dh(N, H)
36+
dx := matrix.MatMul(dt, l.Wx.T()) // matmul(dt(N, H), Wx.T(H, D)) -> dx(N, D)
37+
dh := matrix.MatMul(dt, l.Wh.T()) // matmul(dt(N, H), Wh.T(H, H)) -> dh(N, H)
3838

39-
l.DWx = matrix.Dot(l.x.T(), dt) // dot(x.T(D, N), dt(N, H)) -> (D, H)
40-
l.DWh = matrix.Dot(l.h.T(), dt) // dot(hPrev.T(H, N), dt(N, H)) -> (H, H)
41-
l.DB = matrix.New(dt.SumAxis0()) // sum(dt(N, H), axis=0) -> (1, H)
39+
l.DWx = matrix.MatMul(l.x.T(), dt) // matmul(x.T(D, N), dt(N, H)) -> (D, H)
40+
l.DWh = matrix.MatMul(l.h.T(), dt) // matmul(hPrev.T(H, N), dt(N, H)) -> (H, H)
41+
l.DB = matrix.New(dt.SumAxis0()) // sum(dt(N, H), axis=0) -> (1, H)
4242
return dx, dh
4343
}
4444

math/matrix/matrix.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -339,16 +339,17 @@ func SubC(c float64, m Matrix) Matrix {
339339
return F(m, func(v float64) float64 { return c - v })
340340
}
341341

342-
// Dot returns the dot product of m and n.
343-
func Dot(m, n Matrix) Matrix {
342+
// MatMul returns the matrix product of m and n.
343+
func MatMul(m, n Matrix) Matrix {
344344
a, b := m.Dim()
345345
_, p := n.Dim()
346346

347347
out := Zero(a, p)
348-
for i := 0; i < a; i++ {
349-
for j := 0; j < p; j++ {
350-
for k := 0; k < b; k++ {
351-
out[i][j] = out[i][j] + m[i][k]*n[k][j]
348+
for i := range a {
349+
for k := range b {
350+
mik := m[i][k]
351+
for j := range p {
352+
out[i][j] += mik * n[k][j]
352353
}
353354
}
354355
}

math/matrix/matrix_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ func ExampleInt() {
143143
// [[1 2 3]]
144144
}
145145

146-
func ExampleDot() {
146+
func ExampleMatMul() {
147147
A := matrix.New(
148148
[]float64{1, 2},
149149
[]float64{3, 4},
@@ -154,7 +154,7 @@ func ExampleDot() {
154154
[]float64{7, 8},
155155
)
156156

157-
for _, r := range matrix.Dot(A, B) {
157+
for _, r := range matrix.MatMul(A, B) {
158158
fmt.Println(r)
159159
}
160160

model/cbow.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ func NewCBOW(c *CBOWConfig, s ...randv2.Source) *CBOW {
3030

3131
// model
3232
return &CBOW{
33-
Win0: &layer.Dot{W: matrix.Randn(V, H, s[0]).MulC(0.01)},
34-
Win1: &layer.Dot{W: matrix.Randn(V, H, s[0]).MulC(0.01)},
35-
Wout: &layer.Dot{W: matrix.Randn(H, V, s[0]).MulC(0.01)},
33+
Win0: &layer.MatMul{W: matrix.Randn(V, H, s[0]).MulC(0.01)},
34+
Win1: &layer.MatMul{W: matrix.Randn(V, H, s[0]).MulC(0.01)},
35+
Wout: &layer.MatMul{W: matrix.Randn(H, V, s[0]).MulC(0.01)},
3636
Loss: &layer.SoftmaxWithLoss{},
3737
Source: s[0],
3838
}

0 commit comments

Comments
 (0)