Skip to content

Commit d6a5bf0

Browse files
committed
Update PartialTrace
1 parent bc1d1b8 commit d6a5bf0

File tree

1 file changed

+46
-38
lines changed

1 file changed

+46
-38
lines changed

quantum/density/matrix.go

Lines changed: 46 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package density
33
import (
44
"iter"
55
"math"
6-
"strings"
76

87
"github.com/itsubaki/q/math/epsilon"
98
"github.com/itsubaki/q/math/matrix"
@@ -139,42 +138,33 @@ func (m *Matrix) TensorProduct(n *Matrix) *Matrix {
139138
// where n is the number of qubits in the matrix.
140139
func (m *Matrix) PartialTrace(qb ...int) *Matrix {
141140
n := m.NumQubits()
142-
d := number.Pow(2, n-len(qb))
143-
p, q := m.Dim()
144141

142+
// mask for the qubits to be traced out
143+
var mask int
144+
for _, q := range qb {
145+
mask |= 1 << (n - 1 - q)
146+
}
147+
148+
p, q := m.Dim()
149+
d := 1 << (n - len(qb))
145150
rho := matrix.Zero(d, d)
146151
for i := range p {
147-
k, kr := take(n, i, qb)
152+
ti, ki := split(i, n, mask)
148153

149154
for j := range q {
150-
l, lr := take(n, j, qb)
155+
tj, kj := split(j, n, mask)
151156

152-
if k != l {
157+
if ti != tj {
153158
continue
154159
}
155160

156-
r := number.MustParseInt(kr)
157-
c := number.MustParseInt(lr)
158-
rho.AddAt(r, c, m.At(i, j))
159-
160-
// fmt.Printf("[%v][%v] = [%v][%v] + [%v][%v]\n", r, c, r, c, i, j)
161-
//
162-
// 4x4 explicit
163-
// index -> 0
164-
// out[0][0] = m.m[0][0] + m.m[2][2]
165-
// out[0][1] = m.m[0][1] + m.m[2][3]
166-
// out[1][0] = m.m[1][0] + m.m[3][2]
167-
// out[1][1] = m.m[1][1] + m.m[3][3]
168-
//
169-
// index -> 1
170-
// out[0][0] = m.m[0][0] + m.m[1][1]
171-
// out[0][1] = m.m[0][2] + m.m[1][3]
172-
// out[1][0] = m.m[2][0] + m.m[3][1]
173-
// out[1][1] = m.m[2][2] + m.m[3][3]
161+
rho.AddAt(ki, kj, m.At(i, j))
174162
}
175163
}
176164

177-
return &Matrix{rho: rho}
165+
return &Matrix{
166+
rho: rho,
167+
}
178168
}
179169

180170
// Depolarizing returns the depolarizing channel.
@@ -233,22 +223,40 @@ func (m *Matrix) PhaseFlip(p float64, qb int) *Matrix {
233223
return m.ApplyChannel(p, gate.Z(), qb)
234224
}
235225

236-
func take(n, i int, qb []int) (string, string) {
237-
target := make(map[int]struct{}, len(qb))
238-
for _, j := range qb {
239-
target[j] = struct{}{}
240-
}
241-
242-
var out, remain strings.Builder
243-
for j := range n {
244-
s := byte('0' + ((i >> (n - 1 - j)) & 1))
245-
if _, ok := target[j]; ok {
246-
out.WriteByte(s)
226+
// split separates the bits of x into two integers according to mask.
227+
//
228+
// Bits where mask has value 1 are extracted into the returned trace value,
229+
// preserving their relative order. Bits where mask has value 0 are extracted
230+
// into the returned kept value.
231+
//
232+
// The n parameter specifies the number of bits of x to consider.
233+
//
234+
// For example:
235+
//
236+
// n = 3
237+
// x = 0b101
238+
// mask = 0b010
239+
//
240+
// The bit at position 1 is traced out, so:
241+
//
242+
// trace = 0b0
243+
// kept = 0b11
244+
//
245+
// This helper is used when computing partial traces of density matrices.
246+
func split(x, n, mask int) (int, int) {
247+
var trace, kept, trPos, kpPos int
248+
for i := range n {
249+
bit := (x >> i) & 1
250+
251+
if (mask>>i)&1 == 1 {
252+
trace |= bit << trPos
253+
trPos++
247254
continue
248255
}
249256

250-
remain.WriteByte(s)
257+
kept |= bit << kpPos
258+
kpPos++
251259
}
252260

253-
return out.String(), remain.String()
261+
return trace, kept
254262
}

0 commit comments

Comments
 (0)