Skip to content

Commit 458aa86

Browse files
[tests] improve basis tests
1 parent 32986ef commit 458aa86

1 file changed

Lines changed: 16 additions & 2 deletions

File tree

tests/test_basis.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
from typing import TYPE_CHECKING, Any
2+
13
import numpy as np
24
import pytest
3-
from rydstate import BasisSQDTAlkali
4-
from rydstate.basis.basis_sqdt import BasisSQDTAlkalineLS
5+
from rydstate import BasisSQDTAlkali, BasisSQDTAlkalineFJ, BasisSQDTAlkalineJJ, BasisSQDTAlkalineLS
6+
7+
if TYPE_CHECKING:
8+
from rydstate.basis.basis_base import BasisBase
59

610

711
@pytest.mark.parametrize("species_name", ["Rb", "Na", "H"])
@@ -62,3 +66,13 @@ def test_alkaline_basis(species_name: str) -> None:
6266
me_matrix = basis.calc_reduced_matrix_elements(basis, "electric_dipole", unit="e a0")
6367
assert np.shape(me_matrix) == (len(basis.states), len(basis.states))
6468
assert np.count_nonzero(me_matrix) > 0
69+
70+
basis = BasisSQDTAlkalineLS(species_name, n_min=30, n_max=35)
71+
basis.filter_states("l_r", (6, 10))
72+
for basis_class in [BasisSQDTAlkalineJJ, BasisSQDTAlkalineFJ]:
73+
basis2: BasisBase[Any] = basis_class(species_name, n_min=30, n_max=35) # type: ignore [assignment]
74+
basis2.filter_states("l_r", (6, 10))
75+
assert len(basis2.states) == len(basis.states)
76+
trafo = basis.calc_reduced_overlaps(basis2)
77+
trafo_inv = basis2.calc_reduced_overlaps(basis)
78+
assert np.allclose(trafo @ trafo_inv, np.eye(len(basis.states)), atol=1e-3)

0 commit comments

Comments
 (0)