@@ -520,3 +520,55 @@ def test_check_canonical_form():
520520 # We'll just check that it returns something.
521521 # If it prints "MPS is right (B) canonical." for example, that's okay.
522522 assert res is not None
523+
524+ def test_convert_to_vector ():
525+ """
526+ Tests the MPS_to_vector function for various initial states.
527+ For each state, the expected full state vector is computed as the tensor
528+ product of the corresponding local state vectors.
529+ """
530+ test_states = ["zeros" , "ones" , "x+" , "x-" , "y+" , "y-" , "Neel" , "wall" ]
531+ L = 4 # Use a small number of sites for testing.
532+ tol = 1e-12
533+
534+ def local_state_vector (state_str : str , index : int , L : int ) -> np .ndarray :
535+ """
536+ Returns the local state vector for a given state string.
537+ For 'Neel' and 'wall', the local state depends on the site index.
538+ """
539+ if state_str == "zeros" :
540+ return np .array ([1 , 0 ], dtype = complex )
541+ elif state_str == "ones" :
542+ return np .array ([0 , 1 ], dtype = complex )
543+ elif state_str == "x+" :
544+ return np .array ([1 / np .sqrt (2 ), 1 / np .sqrt (2 )], dtype = complex )
545+ elif state_str == "x-" :
546+ return np .array ([1 / np .sqrt (2 ), - 1 / np .sqrt (2 )], dtype = complex )
547+ elif state_str == "y+" :
548+ return np .array ([1 / np .sqrt (2 ), 1j / np .sqrt (2 )], dtype = complex )
549+ elif state_str == "y-" :
550+ return np .array ([1 / np .sqrt (2 ), - 1j / np .sqrt (2 )], dtype = complex )
551+ elif state_str == "Neel" :
552+ # According to the MPS code: if index is odd, local vector = [1, 0]; if even, [0, 1].
553+ return np .array ([1 , 0 ], dtype = complex ) if index % 2 == 1 else np .array ([0 , 1 ], dtype = complex )
554+ elif state_str == "wall" :
555+ # For a "wall" state: sites with index < L//2 are |0>, else |1>.
556+ return np .array ([1 , 0 ], dtype = complex ) if index < L // 2 else np .array ([0 , 1 ], dtype = complex )
557+ else :
558+ raise ValueError ("Invalid state string" )
559+
560+ for state_str in test_states :
561+ # Create an MPS for the given state.
562+ mps = MPS (length = L , state = state_str )
563+ psi = mps .convert_to_vector ()
564+
565+ # Construct the expected state vector as the Kronecker product of local states.
566+ local_states = [local_state_vector (state_str , i , L ) for i in range (L )]
567+ expected = reduce (np .kron , local_states )
568+
569+ if np .allclose (psi , expected , atol = tol ):
570+ print (f"Test passed for state '{ state_str } '." )
571+ else :
572+ print (f"Test FAILED for state '{ state_str } '." )
573+ print ("Expected:" , expected )
574+ print ("Got :" , psi )
0 commit comments