Skip to content

Commit 29b7369

Browse files
start adding mqdt states and basis
1 parent 0aba070 commit 29b7369

9 files changed

Lines changed: 274 additions & 5 deletions

File tree

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ dependencies = [
4444
]
4545

4646
[project.optional-dependencies]
47+
mqdt = [
48+
"juliacall >= 0.9.24",
49+
]
4750
tests = [
4851
"pytest >= 8.0",
4952
"nbmake >= 1.3",

src/rydstate/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from rydstate import angular, basis, radial, rydberg, species
22
from rydstate.basis import (
3+
BasisMQDT,
34
BasisSQDTAlkali,
45
BasisSQDTAlkalineFJ,
56
BasisSQDTAlkalineJJ,
67
BasisSQDTAlkalineLS,
78
)
89
from rydstate.rydberg import (
10+
RydbergStateMQDT,
911
RydbergStateSQDT,
1012
RydbergStateSQDTAlkali,
1113
RydbergStateSQDTAlkalineFJ,
@@ -15,10 +17,12 @@
1517
from rydstate.units import ureg
1618

1719
__all__ = [
20+
"BasisMQDT",
1821
"BasisSQDTAlkali",
1922
"BasisSQDTAlkalineFJ",
2023
"BasisSQDTAlkalineJJ",
2124
"BasisSQDTAlkalineLS",
25+
"RydbergStateMQDT",
2226
"RydbergStateSQDT",
2327
"RydbergStateSQDTAlkali",
2428
"RydbergStateSQDTAlkalineFJ",

src/rydstate/angular/angular_ket.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from rydstate.species import SpeciesObject
2424

2525
if TYPE_CHECKING:
26+
import juliacall
2627
from typing_extensions import Self
2728

2829
from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers, AngularOperatorType
@@ -738,6 +739,17 @@ def sanity_check(self, msgs: list[str] | None = None) -> None:
738739
super().sanity_check(msgs)
739740

740741

742+
def julia_qn_to_dict(qn: juliacall.AnyValue) -> dict[str, float]:
743+
"""Convert MQDT Julia quantum numbers to dict object."""
744+
if "fjQuantumNumbers" in str(qn):
745+
return dict(s_c=qn.sc, l_c=qn.lc, j_c=qn.Jc, f_c=qn.Fc, l_r=qn.lr, j_r=qn.Jr, f_tot=qn.F) # noqa: C408
746+
if "jjQuantumNumbers" in str(qn):
747+
return dict(s_c=qn.sc, l_c=qn.lc, j_c=qn.Jc, l_r=qn.lr, j_r=qn.Jr, j_tot=qn.J, f_tot=qn.F) # noqa: C408
748+
if "lsQuantumNumbers" in str(qn):
749+
return dict(s_c=qn.sc, s_tot=qn.S, l_c=qn.lc, l_r=qn.lr, l_tot=qn.L, j_tot=qn.J, f_tot=qn.F) # noqa: C408
750+
raise ValueError(f"Unknown MQDT Julia quantum numbers {qn!s}.")
751+
752+
741753
def quantum_numbers_to_angular_ket(
742754
species: str | SpeciesObject,
743755
s_c: float | None = None,

src/rydstate/angular/angular_state.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def __init__(
3434
) -> None:
3535
self.coefficients = np.array(coefficients)
3636
self.kets = kets
37+
self._warn_if_not_normalized = warn_if_not_normalized
3738

3839
if len(coefficients) != len(kets):
3940
raise ValueError("Length of coefficients and kets must be the same.")
@@ -98,7 +99,8 @@ def to(self, coupling_scheme: CouplingScheme) -> AngularState[Any]:
9899
else:
99100
kets.append(scheme_ket)
100101
coefficients.append(coeff * scheme_coeff)
101-
return AngularState(coefficients, kets, warn_if_not_normalized=abs(self.norm - 1) < 1e-10)
102+
warn_if_not_normalized = self._warn_if_not_normalized and (abs(self.norm - 1) < 1e-10)
103+
return AngularState(coefficients, kets, warn_if_not_normalized=warn_if_not_normalized)
102104

103105
def calc_exp_qn(self, q: AngularMomentumQuantumNumbers) -> float:
104106
"""Calculate the expectation value of a quantum number q.

src/rydstate/basis/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,12 @@
11
from rydstate.basis.basis_base import BasisBase
2+
from rydstate.basis.basis_mqdt import BasisMQDT
23
from rydstate.basis.basis_sqdt import BasisSQDTAlkali, BasisSQDTAlkalineFJ, BasisSQDTAlkalineJJ, BasisSQDTAlkalineLS
34

4-
__all__ = ["BasisBase", "BasisSQDTAlkali", "BasisSQDTAlkalineFJ", "BasisSQDTAlkalineJJ", "BasisSQDTAlkalineLS"]
5+
__all__ = [
6+
"BasisBase",
7+
"BasisMQDT",
8+
"BasisSQDTAlkali",
9+
"BasisSQDTAlkalineFJ",
10+
"BasisSQDTAlkalineJJ",
11+
"BasisSQDTAlkalineLS",
12+
]

src/rydstate/basis/basis_mqdt.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
from typing import TYPE_CHECKING, Any
5+
6+
from rydstate.angular.angular_ket import julia_qn_to_dict
7+
from rydstate.basis.basis_base import BasisBase
8+
from rydstate.rydberg.rydberg_mqdt import RydbergStateMQDT
9+
from rydstate.rydberg.rydberg_sqdt import RydbergStateSQDT
10+
11+
if TYPE_CHECKING:
12+
from rydstate.species import SpeciesObject
13+
14+
logger = logging.getLogger(__name__)
15+
16+
try:
17+
USE_JULIACALL = True
18+
from juliacall import (
19+
JuliaError,
20+
Main as jl, # noqa: N813
21+
convert,
22+
)
23+
except ImportError:
24+
USE_JULIACALL = False
25+
26+
27+
if USE_JULIACALL:
28+
try:
29+
jl.seval("using MQDT")
30+
jl.seval("using CGcoefficient")
31+
except JuliaError:
32+
logger.exception("Failed to load Julia MQDT or CGcoefficient package")
33+
USE_JULIACALL = False
34+
35+
FMODEL_MAX_L = {"Sr87": 2, "Sr88": 2, "Yb171": 4, "Yb173": 1, "Yb174": 4}
36+
37+
38+
class BasisMQDT(BasisBase[RydbergStateMQDT[Any]]):
39+
def __init__(
40+
self,
41+
species: str | SpeciesObject,
42+
n_min: int = 0,
43+
n_max: int | None = None,
44+
*,
45+
skip_high_l: bool = True,
46+
model_names: list[str] | None = None,
47+
) -> None:
48+
super().__init__(species)
49+
50+
if not USE_JULIACALL:
51+
raise ImportError("JuliaCall or the MQDT Julia package is not available.")
52+
53+
try:
54+
self.jl_species = getattr(jl.MQDT, self.species.name)
55+
parameters = self.jl_species.PARA
56+
except AttributeError as e:
57+
raise ValueError(f"Species '{species}' is not supported in the MQDT Julia package.") from e
58+
59+
# TODO use n_min and n_max of the different models
60+
61+
if n_max is None:
62+
raise ValueError("n_max must be given")
63+
64+
# initialize Wigner symbol calculation
65+
if skip_high_l:
66+
jl.CGcoefficient.wigner_init_float(5, "Jmax", 9)
67+
else:
68+
jl.CGcoefficient.wigner_init_float(n_max - 1, "Jmax", 9)
69+
70+
logger.debug("Calculating low l MQDT states...")
71+
72+
jl_species_attr_names = [str(name) for name in jl.seval(f"names(MQDT.{self.species.name}, all=true)")]
73+
self.models = {name: getattr(self.jl_species, name) for name in jl_species_attr_names}
74+
self.models = {k: v for k, v in self.models.items() if str(v).startswith("fModel")}
75+
if model_names is not None:
76+
self.models = {k: v for k, v in self.models.items() if k in model_names}
77+
78+
if skip_high_l:
79+
logger.debug("Skipping high l states.")
80+
else:
81+
logger.debug("Calculating high l SQDT states...")
82+
l_start = FMODEL_MAX_L[self.species.name] + 1
83+
high_l_models = {
84+
f"high_l_{l_ryd}": jl.single_channel_models(species, l_ryd, parameters)
85+
for l_ryd in range(l_start, n_max)
86+
}
87+
self.models.update(high_l_models)
88+
89+
model_names = list(self.models.keys())
90+
jl_states = {name: jl.eigenstates(n_min, n_max, model, parameters) for name, model in self.models.items()}
91+
_models_vector = convert(jl.Vector, [self.models[name] for name in model_names])
92+
_jl_states_vector = convert(jl.Vector, [jl_states[name] for name in model_names])
93+
jl_basis = jl.basisarray(_jl_states_vector, _models_vector)
94+
95+
logger.debug("Generated state table with %d states", len(jl_basis.states))
96+
97+
self.states = []
98+
for jl_state in jl_basis.states:
99+
coeffs = jl_state.coeff
100+
nus = jl_state.nu
101+
nu_energy = jl_state.energy
102+
qns = jl_state.channels.i
103+
qns = [julia_qn_to_dict(qn) for qn in qns]
104+
105+
sqdt_states = [RydbergStateSQDT(species, nu=nu, **qn) for nu, qn in zip(nus, qns)]
106+
# check angular and radial are created correctly
107+
[(s.angular, s.radial) for s in sqdt_states]
108+
109+
mqdt_state = RydbergStateMQDT(coeffs, sqdt_states, nu_energy=nu_energy, warn_if_not_normalized=False)
110+
self.states.append(mqdt_state)

src/rydstate/rydberg/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from rydstate.rydberg.rydberg_base import RydbergStateBase
2+
from rydstate.rydberg.rydberg_mqdt import RydbergStateMQDT
23
from rydstate.rydberg.rydberg_sqdt import (
34
RydbergStateSQDT,
45
RydbergStateSQDTAlkali,
@@ -9,6 +10,7 @@
910

1011
__all__ = [
1112
"RydbergStateBase",
13+
"RydbergStateMQDT",
1214
"RydbergStateSQDT",
1315
"RydbergStateSQDTAlkali",
1416
"RydbergStateSQDTAlkalineFJ",
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
5+
6+
import numpy as np
7+
8+
from rydstate.angular import AngularState
9+
from rydstate.rydberg.rydberg_base import RydbergStateBase
10+
from rydstate.rydberg.rydberg_sqdt import RydbergStateSQDT
11+
12+
if TYPE_CHECKING:
13+
from collections.abc import Iterator, Sequence
14+
15+
from rydstate.units import MatrixElementOperator, PintFloat
16+
17+
18+
logger = logging.getLogger(__name__)
19+
20+
21+
_RydbergState = TypeVar("_RydbergState", bound=RydbergStateSQDT)
22+
23+
24+
class RydbergStateMQDT(RydbergStateBase, Generic[_RydbergState]):
25+
angular: AngularState[Any]
26+
"""Return the angular part of the MQDT state as an AngularState."""
27+
28+
def __init__(
29+
self,
30+
coefficients: Sequence[float],
31+
sqdt_states: Sequence[_RydbergState],
32+
*,
33+
nu_energy: float | None = None,
34+
warn_if_not_normalized: bool = True,
35+
) -> None:
36+
self.coefficients = np.array(coefficients)
37+
self.sqdt_states = sqdt_states
38+
self.nu_energy = nu_energy
39+
self.angular = AngularState(self.coefficients.tolist(), [ket.angular for ket in sqdt_states])
40+
41+
if len(coefficients) != len(sqdt_states):
42+
raise ValueError("Length of coefficients and sqdt_states must be the same.")
43+
if not all(type(sqdt_state) is type(sqdt_states[0]) for sqdt_state in sqdt_states):
44+
raise ValueError("All sqdt_states must be of the same type.")
45+
if len(set(sqdt_states)) != len(sqdt_states):
46+
raise ValueError("RydbergStateMQDT initialized with duplicate sqdt_states.")
47+
if abs(self.norm - 1) > 1e-10 and warn_if_not_normalized:
48+
logger.warning(
49+
"RydbergStateMQDT initialized with non-normalized coefficients "
50+
"(norm=%s, coefficients=%s, sqdt_states=%s)",
51+
self.norm,
52+
coefficients,
53+
sqdt_states,
54+
)
55+
if self.norm > 1:
56+
self.coefficients /= self.norm
57+
58+
def __iter__(self) -> Iterator[tuple[float, _RydbergState]]:
59+
return zip(self.coefficients, self.sqdt_states).__iter__()
60+
61+
def __repr__(self) -> str:
62+
terms = [f"{coeff}*{sqdt_state!r}" for coeff, sqdt_state in self]
63+
return f"{self.__class__.__name__}({', '.join(terms)})"
64+
65+
def __str__(self) -> str:
66+
terms = [f"{coeff}*{sqdt_state!s}" for coeff, sqdt_state in self]
67+
return f"{', '.join(terms)}"
68+
69+
@property
70+
def norm(self) -> float:
71+
"""Return the norm of the state (should be 1)."""
72+
return np.linalg.norm(self.coefficients) # type: ignore [return-value]
73+
74+
def calc_reduced_overlap(self, other: RydbergStateBase) -> float:
75+
"""Calculate the reduced overlap <self|other> (ignoring the magnetic quantum number m)."""
76+
if isinstance(other, RydbergStateSQDT):
77+
other = other.to_mqdt()
78+
79+
if isinstance(other, RydbergStateMQDT):
80+
ov = 0
81+
for coeff1, sqdt1 in self:
82+
for coeff2, sqdt2 in other:
83+
ov += np.conjugate(coeff1) * coeff2 * sqdt1.calc_reduced_overlap(sqdt2)
84+
return ov
85+
86+
raise NotImplementedError(f"calc_reduced_overlap not implemented for {type(self)=}, {type(other)=}")
87+
88+
@overload # type: ignore [override]
89+
def calc_reduced_matrix_element(
90+
self, other: RydbergStateBase, operator: MatrixElementOperator, unit: None = None
91+
) -> PintFloat: ...
92+
93+
@overload
94+
def calc_reduced_matrix_element(
95+
self, other: RydbergStateBase, operator: MatrixElementOperator, unit: str
96+
) -> float: ...
97+
98+
def calc_reduced_matrix_element(
99+
self, other: RydbergStateBase, operator: MatrixElementOperator, unit: str | None = None
100+
) -> PintFloat | float:
101+
r"""Calculate the reduced angular matrix element.
102+
103+
This means, calculate the following matrix element:
104+
105+
.. math::
106+
\left\langle self || \hat{O}^{(\kappa)} || other \right\rangle
107+
108+
"""
109+
if isinstance(other, RydbergStateSQDT):
110+
other = other.to_mqdt()
111+
112+
if isinstance(other, RydbergStateMQDT):
113+
value = 0
114+
for coeff1, sqdt1 in self:
115+
for coeff2, sqdt2 in other:
116+
value += (
117+
np.conjugate(coeff1) * coeff2 * sqdt1.calc_reduced_matrix_element(sqdt2, operator, unit=unit)
118+
)
119+
return value
120+
121+
raise NotImplementedError(f"calc_reduced_overlap not implemented for {type(self)=}, {type(other)=}")

src/rydstate/rydberg/rydberg_sqdt.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44
import math
55
from functools import cached_property
6-
from typing import TYPE_CHECKING, overload
6+
from typing import TYPE_CHECKING, Any, overload
77

88
import numpy as np
99

@@ -15,6 +15,7 @@
1515
from rydstate.units import BaseQuantities, MatrixElementOperatorRanks, ureg
1616

1717
if TYPE_CHECKING:
18+
from rydstate import RydbergStateMQDT
1819
from rydstate.angular.angular_ket import AngularKetBase, AngularKetFJ, AngularKetJJ, AngularKetLS
1920
from rydstate.units import MatrixElementOperator, PintFloat
2021

@@ -182,10 +183,16 @@ def get_energy(self, unit: str | None = None) -> PintFloat | float:
182183
return energy
183184
return energy.to(unit, "spectroscopy").magnitude
184185

186+
def to_mqdt(self) -> RydbergStateMQDT[Any]:
187+
"""Convert to a trivial RydbergMQDT state with only one contribution with coefficient 1."""
188+
from rydstate import RydbergStateMQDT # noqa: PLC0415
189+
190+
return RydbergStateMQDT([1], [self])
191+
185192
def calc_reduced_overlap(self, other: RydbergStateBase) -> float:
186193
"""Calculate the reduced overlap <self|other> (ignoring the magnetic quantum number m)."""
187194
if not isinstance(other, RydbergStateSQDT):
188-
raise NotImplementedError("Reduced overlap only implemented between RydbergStateSQDT states.")
195+
return self.to_mqdt().calc_reduced_overlap(other)
189196

190197
radial_overlap = self.radial.calc_overlap(other.radial)
191198
angular_overlap = self.angular.calc_reduced_overlap(other.angular)
@@ -226,7 +233,7 @@ def calc_reduced_matrix_element(
226233
227234
"""
228235
if not isinstance(other, RydbergStateSQDT):
229-
raise NotImplementedError("Reduced matrix element only implemented between RydbergStateSQDT states.")
236+
return self.to_mqdt().calc_reduced_matrix_element(other, operator, unit=unit)
230237

231238
if operator not in MatrixElementOperatorRanks:
232239
raise ValueError(

0 commit comments

Comments
 (0)