Skip to content

Commit e7815d2

Browse files
Added test for convert_to_vector
1 parent c8d6462 commit e7815d2

File tree

1 file changed

+52
-0
lines changed

1 file changed

+52
-0
lines changed

tests/core/data_structures/test_networks.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,3 +520,55 @@ def test_check_canonical_form():
520520
# We'll just check that it returns something.
521521
# If it prints "MPS is right (B) canonical." for example, that's okay.
522522
assert res is not None
523+
524+
def test_convert_to_vector():
525+
"""
526+
Tests the MPS_to_vector function for various initial states.
527+
For each state, the expected full state vector is computed as the tensor
528+
product of the corresponding local state vectors.
529+
"""
530+
test_states = ["zeros", "ones", "x+", "x-", "y+", "y-", "Neel", "wall"]
531+
L = 4 # Use a small number of sites for testing.
532+
tol = 1e-12
533+
534+
def local_state_vector(state_str: str, index: int, L: int) -> np.ndarray:
535+
"""
536+
Returns the local state vector for a given state string.
537+
For 'Neel' and 'wall', the local state depends on the site index.
538+
"""
539+
if state_str == "zeros":
540+
return np.array([1, 0], dtype=complex)
541+
elif state_str == "ones":
542+
return np.array([0, 1], dtype=complex)
543+
elif state_str == "x+":
544+
return np.array([1/np.sqrt(2), 1/np.sqrt(2)], dtype=complex)
545+
elif state_str == "x-":
546+
return np.array([1/np.sqrt(2), -1/np.sqrt(2)], dtype=complex)
547+
elif state_str == "y+":
548+
return np.array([1/np.sqrt(2), 1j/np.sqrt(2)], dtype=complex)
549+
elif state_str == "y-":
550+
return np.array([1/np.sqrt(2), -1j/np.sqrt(2)], dtype=complex)
551+
elif state_str == "Neel":
552+
# According to the MPS code: if index is odd, local vector = [1, 0]; if even, [0, 1].
553+
return np.array([1, 0], dtype=complex) if index % 2 == 1 else np.array([0, 1], dtype=complex)
554+
elif state_str == "wall":
555+
# For a "wall" state: sites with index < L//2 are |0>, else |1>.
556+
return np.array([1, 0], dtype=complex) if index < L // 2 else np.array([0, 1], dtype=complex)
557+
else:
558+
raise ValueError("Invalid state string")
559+
560+
for state_str in test_states:
561+
# Create an MPS for the given state.
562+
mps = MPS(length=L, state=state_str)
563+
psi = mps.convert_to_vector()
564+
565+
# Construct the expected state vector as the Kronecker product of local states.
566+
local_states = [local_state_vector(state_str, i, L) for i in range(L)]
567+
expected = reduce(np.kron, local_states)
568+
569+
if np.allclose(psi, expected, atol=tol):
570+
print(f"Test passed for state '{state_str}'.")
571+
else:
572+
print(f"Test FAILED for state '{state_str}'.")
573+
print("Expected:", expected)
574+
print("Got :", psi)

0 commit comments

Comments
 (0)