Skip to content

Commit e7bea6f

Browse files
N McDowallN McDowall
authored andcommitted
Add phase-agnostic state comparator and tests (fixes #304)
1 parent f2a2c21 commit e7bea6f

File tree

2 files changed

+78
-0
lines changed

2 files changed

+78
-0
lines changed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import numpy as np
2+
from projectq import MainEngine
3+
from projectq.ops import H, CNOT
4+
from projectq.tests.helpers import PhaseAgnosticStateComparator
5+
6+
def test_hadamard_twice():
7+
eng = MainEngine()
8+
q = eng.allocate_qureg(1)
9+
H | q[0]
10+
H | q[0]
11+
eng.flush()
12+
_, actual = eng.backend.cheat()
13+
expected = np.array([1, 0], dtype=complex)
14+
comparator = PhaseAgnosticStateComparator()
15+
comparator.compare(actual, expected)
16+
17+
def test_bell_state():
18+
eng = MainEngine()
19+
q = eng.allocate_qureg(2)
20+
H | q[0]
21+
CNOT | (q[0], q[1])
22+
eng.flush()
23+
_, actual = eng.backend.cheat()
24+
expected = np.array([1/np.sqrt(2), 0, 0, 1/np.sqrt(2)], dtype=complex)
25+
comparator = PhaseAgnosticStateComparator()
26+
comparator.compare(actual, expected)

projectq/tests/helpers.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import numpy as np
2+
3+
class PhaseAgnosticStateComparator:
4+
def __init__(self, tol=1e-8):
5+
self.tol = tol
6+
7+
def _validate_state(self, state, label="input"):
8+
if not isinstance(state, np.ndarray):
9+
raise TypeError(f"{label} must be a NumPy array.")
10+
if state.ndim not in (1, 2):
11+
raise ValueError(f"{label} must be a 1D or 2D array, got shape {state.shape}.")
12+
if not np.isfinite(state).all():
13+
raise ValueError(f"{label} contains NaN or Inf.")
14+
if state.ndim == 1 and np.linalg.norm(state) == 0:
15+
raise ValueError(f"{label} is a zero vector and cannot be normalized.")
16+
if state.ndim == 2:
17+
for i, vec in enumerate(state):
18+
if np.linalg.norm(vec) == 0:
19+
raise ValueError(f"{label}[{i}] is a zero vector.")
20+
21+
def normalize(self, state, label="input"):
22+
self._validate_state(state, label)
23+
if state.ndim == 1:
24+
return state / np.linalg.norm(state)
25+
else:
26+
return np.array([vec / np.linalg.norm(vec) for vec in state])
27+
28+
def align_phase(self, actual, expected):
29+
a = self.normalize(actual, "actual")
30+
b = self.normalize(expected, "expected")
31+
if a.ndim == 1:
32+
phase = np.vdot(b, a)
33+
return actual * phase.conjugate()
34+
else:
35+
aligned = []
36+
for i in range(a.shape[0]):
37+
phase = np.vdot(b[i], a[i])
38+
aligned_vec = actual[i] * phase.conjugate()
39+
aligned.append(aligned_vec)
40+
return np.array(aligned)
41+
42+
def compare(self, actual, expected):
43+
aligned = self.align_phase(actual, expected)
44+
if actual.ndim == 1:
45+
if not np.allclose(aligned, expected, atol=self.tol):
46+
raise AssertionError("States differ beyond tolerance.")
47+
else:
48+
for i in range(actual.shape[0]):
49+
if not np.allclose(aligned[i], expected[i], atol=self.tol):
50+
raise AssertionError(
51+
f"State {i} differs:\nAligned: {aligned[i]}\nExpected: {expected[i]}"
52+
)

0 commit comments

Comments
 (0)