Skip to content

Commit 8f810ad

Browse files
committed
Update matrix operation
1 parent 1801a14 commit 8f810ad

File tree

4 files changed

+47
-41
lines changed

4 files changed

+47
-41
lines changed

math/matrix/matrix.go

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ func Zero(rows, cols int) Matrix {
4444
}
4545

4646
// Identity returns an identity matrix.
47-
func Identity(rows, cols int) Matrix {
48-
m := Zero(rows, cols)
49-
for i := range rows {
47+
func Identity(size int) Matrix {
48+
m := Zero(size, size)
49+
for i := range size {
5050
m.Set(i, i, 1)
5151
}
5252

@@ -68,6 +68,16 @@ func (m Matrix) Set(i, j int, v complex128) {
6868
m.Data[i*m.Cols+j] = v
6969
}
7070

71+
// AddAt adds a value of matrix at (i,j).
72+
func (m Matrix) AddAt(i, j int, v complex128) {
73+
m.Data[i*m.Cols+j] += v
74+
}
75+
76+
// MulAt multiplies a value of matrix at (i,j).
77+
func (m Matrix) MulAt(i, j int, v complex128) {
78+
m.Data[i*m.Cols+j] *= v
79+
}
80+
7181
// Seq2 returns a sequence of rows.
7282
func (m Matrix) Seq2() iter.Seq2[int, []complex128] {
7383
return func(yield func(int, []complex128) bool) {
@@ -169,7 +179,7 @@ func (m Matrix) IsUnitary(eps ...float64) bool {
169179
}
170180

171181
mmd := m.Apply(m.Dagger())
172-
id := Identity(m.Dimension())
182+
id := Identity(m.Rows)
173183
return mmd.Equals(id, epsilon.E13(eps...))
174184
}
175185

@@ -181,13 +191,11 @@ func (m Matrix) Apply(n Matrix) Matrix {
181191

182192
out := Zero(a, p)
183193
for i := range a {
184-
for j := range p {
185-
var c complex128
186-
for k := range b {
187-
c = c + n.At(i, k)*m.At(k, j)
194+
for k := range b {
195+
nik := n.At(i, k)
196+
for j := range p {
197+
out.AddAt(i, j, nik*m.At(k, j))
188198
}
189-
190-
out.Set(i, j, c)
191199
}
192200
}
193201

@@ -276,12 +284,12 @@ func (m Matrix) Inverse() Matrix {
276284
p, q := m.Dimension()
277285
mm := m.Clone()
278286

279-
out := Identity(p, q)
287+
out := Identity(p)
280288
for i := range p {
281289
c := 1 / mm.At(i, i)
282290
for j := range q {
283-
mm.Set(i, j, c*mm.At(i, j))
284-
out.Set(i, j, c*out.At(i, j))
291+
mm.MulAt(i, j, c)
292+
out.MulAt(i, j, c)
285293
}
286294

287295
for j := range q {
@@ -291,8 +299,8 @@ func (m Matrix) Inverse() Matrix {
291299

292300
c := mm.At(j, i)
293301
for k := range q {
294-
mm.Set(j, k, mm.At(j, k)-c*mm.At(i, k))
295-
out.Set(j, k, out.At(j, k)-c*out.At(i, k))
302+
mm.AddAt(j, k, -c*mm.At(i, k))
303+
out.AddAt(j, k, -c*out.At(i, k))
296304
}
297305
}
298306
}
@@ -306,14 +314,14 @@ func (m Matrix) TensorProduct(n Matrix) Matrix {
306314
a, b := n.Dimension()
307315
rows, cols := p*a, q*b
308316

309-
var idx int
310317
data := make([]complex128, rows*cols)
311318
for i := range p {
312-
for k := range a {
313-
for j := range q {
319+
for j := range q {
320+
mij := m.At(i, j)
321+
for k := range a {
314322
for l := range b {
315-
data[idx] = m.At(i, j) * n.At(k, l)
316-
idx++
323+
row, col := i*a+k, j*b+l
324+
data[row*cols+col] = mij * n.At(k, l)
317325
}
318326
}
319327
}

math/vector/vector.go

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,10 @@ func (v Vector) Mul(z complex128) Vector {
6666
func (v Vector) TensorProduct(w Vector) Vector {
6767
p, q := len(v), len(w)
6868

69-
var idx int
7069
out := make(Vector, p*q)
7170
for i := range p {
7271
for j := range q {
73-
out[idx] = v[i] * w[j]
74-
idx++
72+
out[i*q+j] = v[i] * w[j]
7573
}
7674
}
7775

@@ -92,21 +90,19 @@ func (v Vector) InnerProduct(w Vector) complex128 {
9290

9391
// OuterProduct returns the outer product of v and w.
9492
func (v Vector) OuterProduct(w Vector) matrix.Matrix {
95-
p, q := len(v), len(w)
93+
rows, cols := len(v), len(w)
9694
dual := w.Dual()
9795

98-
var idx int
99-
data := make([]complex128, p*q)
96+
data := make([]complex128, rows*cols)
10097
for i := range v {
10198
for j := range dual {
102-
data[idx] = v[i] * dual[j]
103-
idx++
99+
data[i*cols+j] = v[i] * dual[j]
104100
}
105101
}
106102

107103
return matrix.Matrix{
108-
Rows: p,
109-
Cols: q,
104+
Rows: rows,
105+
Cols: cols,
110106
Data: data,
111107
}
112108
}

quantum/density/matrix.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ func New(ensemble []State) *Matrix {
3939
op := s.Qubit.OuterProduct(s.Qubit).Mul(complex(s.Probability, 0))
4040
for i := range n {
4141
for j := range n {
42-
m.Set(i, j, m.At(i, j)+op.At(i, j))
42+
m.AddAt(i, j, op.At(i, j))
4343
}
4444
}
4545
}
@@ -129,7 +129,7 @@ func (m *Matrix) PartialTrace(index ...Qubit) (*Matrix, error) {
129129

130130
r := int(number.Must(strconv.ParseInt(kr, 2, 0)))
131131
c := int(number.Must(strconv.ParseInt(lr, 2, 0)))
132-
out.Set(r, c, out.At(r, c)+m.m.At(i, j))
132+
out.AddAt(r, c, m.m.At(i, j))
133133

134134
// fmt.Printf("[%v][%v] = [%v][%v] + [%v][%v]\n", r, c, r, c, i, j)
135135
//
@@ -160,7 +160,9 @@ func (m *Matrix) Depolarizing(p float64) (*Matrix, error) {
160160
n := m.NumQubits()
161161
i := gate.I(n).Mul(complex(p/2, 0))
162162
r := m.m.Mul(complex(1-p, 0))
163-
return &Matrix{i.Add(r)}, nil
163+
return &Matrix{
164+
m: i.Add(r),
165+
}, nil
164166
}
165167

166168
func take(n, i int, index []Qubit) (string, string) {

quantum/gate/gate.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -139,12 +139,12 @@ func Controlled(u matrix.Matrix, n int, c []int, t int) matrix.Matrix {
139139

140140
s := (1 << n)
141141
g := I(n)
142-
for i := 0; i < s; i++ {
142+
for i := range s {
143143
if (i & mask) != mask {
144144
continue
145145
}
146146

147-
for j := 0; j < s; j++ {
147+
for j := range s {
148148
if (j & mask) != mask {
149149
continue
150150
}
@@ -215,9 +215,9 @@ func ControlledZ(n int, c []int, t int) matrix.Matrix {
215215
}
216216

217217
g := I(n)
218-
for i := 0; i < (1 << n); i++ {
218+
for i := range 1 << n {
219219
if (i&mask) == mask && (i&(1<<(n-1-t))) != 0 {
220-
g.Set(i, i, -1*g.At(i, i))
220+
g.MulAt(i, i, -1)
221221
}
222222
}
223223

@@ -237,9 +237,9 @@ func ControlledS(n int, c []int, t int) matrix.Matrix {
237237
}
238238

239239
g := I(n)
240-
for i := 0; i < (1 << n); i++ {
240+
for i := range 1 << n {
241241
if (i&mask) == mask && (i&(1<<(n-1-t))) != 0 {
242-
g.Set(i, i, 1i*g.At(i, i))
242+
g.MulAt(i, i, 1i)
243243
}
244244
}
245245

@@ -262,9 +262,9 @@ func ControlledR(theta float64, n int, c []int, t int) matrix.Matrix {
262262
}
263263

264264
g := I(n)
265-
for i := 0; i < (1 << n); i++ {
265+
for i := range 1 << n {
266266
if (i&mask) == mask && (i&(1<<(n-1-t))) != 0 {
267-
g.Set(i, i, e*g.At(i, i))
267+
g.MulAt(i, i, e)
268268
}
269269
}
270270

0 commit comments

Comments
 (0)