Skip to content

Commit 394981d

Browse files
various small improvements
1 parent 1bd84f2 commit 394981d

7 files changed

Lines changed: 89 additions & 90 deletions

File tree

src/rydstate/angular/angular_ket.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,16 @@
2020
try_trivial_spin_addition,
2121
)
2222
from rydstate.angular.wigner_symbols import calc_wigner_3j, clebsch_gordan_6j, clebsch_gordan_9j
23-
from rydstate.species import SpeciesObject
2423

2524
if TYPE_CHECKING:
25+
from collections.abc import Sequence
26+
2627
from typing_extensions import Self
2728

2829
from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers, AngularOperatorType
2930
from rydstate.angular.angular_state import AngularState
3031
from rydstate.angular.utils import CouplingScheme
32+
from rydstate.species import SpeciesObject
3133

3234
logger = logging.getLogger(__name__)
3335

@@ -83,12 +85,13 @@ def __init__(
8385
) -> None:
8486
"""Initialize the Spin ket.
8587
86-
species:
87-
Atomic species, e.g. 'Rb87'.
88-
Not used for calculation, only for convenience to infer the core electron spin and nuclear spin quantum numbers.
88+
Atomic species, e.g. 'Rb87', will not be used for calculation,
89+
only for convenience to infer the core electron spin and nuclear spin quantum numbers.
8990
"""
9091
if species is not None:
9192
if isinstance(species, str):
93+
from rydstate.species import SpeciesObject # noqa: PLC0415
94+
9295
species = SpeciesObject.from_name(species)
9396
# use i_c = 0 for species without defined nuclear spin (-> ignore hyperfine)
9497
species_i_c = species.i_c if species.i_c is not None else 0
@@ -228,10 +231,8 @@ def to_state(self, coupling_scheme: CouplingScheme | None = None) -> AngularStat
228231
The angular state in the specified coupling scheme.
229232
230233
"""
231-
from rydstate.angular.angular_state import AngularState # noqa: PLC0415
232-
233234
if coupling_scheme is None or coupling_scheme == self.coupling_scheme:
234-
return AngularState([1], [self])
235+
return self._create_angular_state([1], [self])
235236
if coupling_scheme == "LS":
236237
return self._to_state_ls()
237238
if coupling_scheme == "JJ":
@@ -271,9 +272,7 @@ def _to_state_ls(self) -> AngularState[AngularKetLS]:
271272
kets.append(ls_ket)
272273
coefficients.append(coeff)
273274

274-
from rydstate.angular.angular_state import AngularState # noqa: PLC0415
275-
276-
return AngularState(coefficients, kets)
275+
return self._create_angular_state(coefficients, kets)
277276

278277
def _to_state_jj(self) -> AngularState[AngularKetJJ]:
279278
"""Convert a single ket to state in JJ coupling."""
@@ -306,9 +305,7 @@ def _to_state_jj(self) -> AngularState[AngularKetJJ]:
306305
kets.append(jj_ket)
307306
coefficients.append(coeff)
308307

309-
from rydstate.angular.angular_state import AngularState # noqa: PLC0415
310-
311-
return AngularState(coefficients, kets)
308+
return self._create_angular_state(coefficients, kets)
312309

313310
def _to_state_fj(self) -> AngularState[AngularKetFJ]:
314311
"""Convert a single ket to state in FJ coupling."""
@@ -341,6 +338,10 @@ def _to_state_fj(self) -> AngularState[AngularKetFJ]:
341338
kets.append(fj_ket)
342339
coefficients.append(coeff)
343340

341+
return self._create_angular_state(coefficients, kets)
342+
343+
def _create_angular_state(self, coefficients: Sequence[float], kets: Sequence[AngularKetBase]) -> AngularState[Any]:
344+
"""Create an AngularState from coefficients and kets."""
344345
from rydstate.angular.angular_state import AngularState # noqa: PLC0415
345346

346347
return AngularState(coefficients, kets)

src/rydstate/angular/angular_state.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from rydstate.angular.angular_matrix_element import AngularMomentumQuantumNumbers, AngularOperatorType
2323
from rydstate.angular.utils import CouplingScheme
24+
from rydstate.units import NDArray
2425

2526
logger = logging.getLogger(__name__)
2627

@@ -30,7 +31,11 @@
3031

3132
class AngularState(Generic[_AngularKet]):
3233
def __init__(
33-
self, coefficients: Sequence[float], kets: Sequence[_AngularKet], *, warn_if_not_normalized: bool = True
34+
self,
35+
coefficients: Sequence[float] | NDArray,
36+
kets: Sequence[_AngularKet],
37+
*,
38+
warn_if_not_normalized: bool = True,
3439
) -> None:
3540
self.coefficients = np.array(coefficients)
3641
self.kets = kets

src/rydstate/basis/basis_base.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def filter_states(self, qn: str, value: float | tuple[float, float], *, delta: f
5050

5151
if is_angular_momentum_quantum_number(qn):
5252
self.states = [state for state in self.states if qn_min <= state.angular.calc_exp_qn(qn) <= qn_max]
53-
elif qn in ["n", "nu", "nu_energy"]:
53+
elif qn in ["n", "nu", "nu_ref"]:
5454
self.states = [state for state in self.states if qn_min <= getattr(state, qn) <= qn_max]
5555
else:
5656
raise ValueError(f"Unknown quantum number {qn}")
@@ -68,18 +68,18 @@ def sort_states(self, *qns: str) -> Self:
6868
self.states = [self.states[i] for i in sorted_indices]
6969
return self
7070

71-
def calc_exp_qn(self, qn: str) -> list[float]:
71+
def calc_exp_qn(self, qn: str) -> NDArray:
7272
if is_angular_momentum_quantum_number(qn):
73-
return [state.angular.calc_exp_qn(qn) for state in self.states]
74-
if qn in ["n", "nu", "nu_energy"]:
75-
return [getattr(state, qn) for state in self.states]
73+
return np.array([state.angular.calc_exp_qn(qn) for state in self.states])
74+
if qn in ["n", "nu", "nu_ref"]:
75+
return np.array([getattr(state, qn) for state in self.states])
7676
raise ValueError(f"Unknown quantum number {qn}")
7777

78-
def calc_std_qn(self, qn: str) -> list[float]:
78+
def calc_std_qn(self, qn: str) -> NDArray:
7979
if is_angular_momentum_quantum_number(qn):
80-
return [state.angular.calc_std_qn(qn) for state in self.states]
81-
if qn in ["n", "nu", "nu_energy"]:
82-
return [0 for state in self.states]
80+
return np.array([state.angular.calc_std_qn(qn) for state in self.states])
81+
if qn in ["n", "nu", "nu_ref"]:
82+
return np.zeros(len(self.states))
8383
raise ValueError(f"Unknown quantum number {qn}")
8484

8585
def calc_reduced_overlap(self, other: RydbergStateBase) -> NDArray:

src/rydstate/basis/basis_sqdt.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import logging
4+
from typing import TYPE_CHECKING
45

56
import numpy as np
67

@@ -12,11 +13,14 @@
1213
RydbergStateSQDTAlkalineLS,
1314
)
1415

16+
if TYPE_CHECKING:
17+
from rydstate.species.species_object import SpeciesObject
18+
1519
logger = logging.getLogger(__name__)
1620

1721

1822
class BasisSQDTAlkali(BasisBase[RydbergStateSQDTAlkali]):
19-
def __init__(self, species: str, n_min: int = 1, n_max: int | None = None) -> None:
23+
def __init__(self, species: str | SpeciesObject, n_min: int = 1, n_max: int | None = None) -> None:
2024
super().__init__(species)
2125

2226
if n_max is None:
@@ -37,7 +41,7 @@ def __init__(self, species: str, n_min: int = 1, n_max: int | None = None) -> No
3741

3842

3943
class BasisSQDTAlkalineLS(BasisBase[RydbergStateSQDTAlkalineLS]):
40-
def __init__(self, species: str, n_min: int = 1, n_max: int | None = None) -> None:
44+
def __init__(self, species: str | SpeciesObject, n_min: int = 1, n_max: int | None = None) -> None:
4145
super().__init__(species)
4246

4347
if n_max is None:
@@ -60,7 +64,7 @@ def __init__(self, species: str, n_min: int = 1, n_max: int | None = None) -> No
6064

6165

6266
class BasisSQDTAlkalineJJ(BasisBase[RydbergStateSQDTAlkalineJJ]):
63-
def __init__(self, species: str, n_min: int = 0, n_max: int | None = None) -> None:
67+
def __init__(self, species: str | SpeciesObject, n_min: int = 0, n_max: int | None = None) -> None:
6468
super().__init__(species)
6569

6670
if n_max is None:
@@ -90,7 +94,7 @@ def __init__(self, species: str, n_min: int = 0, n_max: int | None = None) -> No
9094

9195

9296
class BasisSQDTAlkalineFJ(BasisBase[RydbergStateSQDTAlkalineFJ]):
93-
def __init__(self, species: str, n_min: int = 0, n_max: int | None = None) -> None:
97+
def __init__(self, species: str | SpeciesObject, n_min: int = 0, n_max: int | None = None) -> None:
9498
super().__init__(species)
9599

96100
if n_max is None:

src/rydstate/rydberg/rydberg_base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,8 @@ def calc_reduced_overlap(self, other: RydbergStateBase) -> float: ...
2222
def calc_reduced_matrix_element(
2323
self, other: RydbergStateBase, operator: MatrixElementOperator, unit: str | None = None
2424
) -> PintFloat | float: ...
25+
26+
@property
27+
@abstractmethod
28+
def nu_ref(self) -> float:
29+
"""The reference effective principal quantum number nu_ref."""

src/rydstate/rydberg/rydberg_sqdt.py

Lines changed: 46 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from rydstate.units import BaseQuantities, MatrixElementOperatorRanks, ureg
1616

1717
if TYPE_CHECKING:
18+
from typing_extensions import Self
19+
1820
from rydstate.angular.angular_ket import AngularKetBase, AngularKetFJ, AngularKetJJ, AngularKetLS
1921
from rydstate.units import MatrixElementOperator, PintFloat
2022

@@ -94,14 +96,19 @@ def __init__(
9496
if nu is None and n is None:
9597
raise ValueError("Either n or nu must be given to initialize the Rydberg state.")
9698

99+
self._set_qn_as_attributes()
100+
101+
def _set_qn_as_attributes(self) -> None:
102+
pass
103+
97104
@classmethod
98105
def from_angular_ket(
99-
cls,
106+
cls: type[Self],
100107
species: str | SpeciesObject,
101108
angular_ket: AngularKetBase,
102109
n: int | None = None,
103110
nu: float | None = None,
104-
) -> RydbergStateSQDT:
111+
) -> Self:
105112
"""Initialize the Rydberg state from an angular ket."""
106113
obj = cls.__new__(cls)
107114

@@ -115,13 +122,14 @@ def from_angular_ket(
115122
raise ValueError("Either n or nu must be given to initialize the Rydberg state.")
116123

117124
obj.angular = angular_ket
125+
obj._set_qn_as_attributes() # noqa: SLF001
118126

119127
return obj
120128

121129
def __repr__(self) -> str:
122130
species, n, nu = self.species.name, self.n, self.nu
123131
n_str = f", {n=}" if n is not None else ""
124-
return f"{self.__class__.__name__}({species=}{n_str}, {nu=}, {self.angular})"
132+
return f"{self.__class__.__name__}({species}{n_str}, {nu=}, {self.angular})"
125133

126134
def __str__(self) -> str:
127135
return self.__repr__()
@@ -151,13 +159,19 @@ def nu(self) -> float:
151159
"""The effective principal quantum number nu (for alkali atoms also known as n*) for the Rydberg state."""
152160
if self._nu is not None:
153161
return self._nu
154-
assert self.n is not None
155-
if any(qn not in self.angular.quantum_number_names for qn in ["j_tot", "s_tot"]):
156-
raise ValueError("j_tot and s_tot must be defined to calculate nu from n.")
162+
assert isinstance(self.species, SpeciesObject), "nu must be given if not sqdt"
163+
assert self.n is not None, "either nu or n must be given"
164+
165+
if "j_tot" not in self.angular.quantum_number_names or "s_tot" not in self.angular.quantum_number_names:
166+
raise RuntimeError("j_tot and s_tot must be defined in the angular ket to calculate nu from n.")
157167
return self.species.calc_nu(
158168
self.n, self.angular.l_r, self.angular.get_qn("j_tot"), s_tot=self.angular.get_qn("s_tot")
159169
)
160170

171+
@property
172+
def nu_ref(self) -> float:
173+
return self.nu
174+
161175
@overload
162176
def get_energy(self, unit: None = None) -> PintFloat: ...
163177

@@ -342,15 +356,18 @@ def __init__(
342356
"""
343357
super().__init__(species=species, n=n, nu=nu, l_r=l, j_tot=j, f_tot=f, m=m)
344358

345-
self.l = l
359+
def _set_qn_as_attributes(self) -> None:
360+
self.l = self.angular.l_r
346361
self.j = self.angular.j_tot
347362
self.f = self.angular.f_tot
348-
self.m = m
363+
self.m = self.angular.m
349364

350365
def __repr__(self) -> str:
351-
species, n, l, j, f, m = self.species, self.n, self.l, self.j, self.f, self.m
366+
species, n, nu = self.species.name, self.n, self.nu
367+
l, j, f, m = self.l, self.j, self.f, self.m
368+
n_str = f", {n=}" if n is not None else ""
352369
f_string = f", {f=}" if self.species.i_c not in (None, 0) else ""
353-
return f"{self.__class__.__name__}({species.name}, {n=}, {l=}, {j=}{f_string}, {m=})"
370+
return f"{self.__class__.__name__}({species}{n_str}, {nu=}, {l=}, {j=}{f_string}, {m=})"
354371

355372

356373
class RydbergStateSQDTAlkalineLS(RydbergStateSQDT):
@@ -388,15 +405,18 @@ def __init__(
388405
"""
389406
super().__init__(species=species, n=n, nu=nu, l_r=l, s_tot=s_tot, j_tot=j_tot, f_tot=f_tot, m=m)
390407

391-
self.l = l
408+
def _set_qn_as_attributes(self) -> None:
409+
self.l = self.angular.l_r
392410
self.s_tot = self.angular.s_tot
393411
self.j_tot = self.angular.j_tot
394412
self.f_tot = self.angular.f_tot
395-
self.m = m
413+
self.m = self.angular.m
396414

397415
def __repr__(self) -> str:
398-
species, n, l, s_tot, j_tot, f_tot, m = self.species, self.n, self.l, self.s_tot, self.j_tot, self.f_tot, self.m
399-
return f"{self.__class__.__name__}({species.name}, {n=}, {l=}, {s_tot=}, {j_tot=}, {f_tot=}, {m=})"
416+
species, n, nu = self.species.name, self.n, self.nu
417+
l, s_tot, j_tot, f_tot, m = self.l, self.s_tot, self.j_tot, self.f_tot, self.m
418+
n_str = f", {n=}" if n is not None else ""
419+
return f"{self.__class__.__name__}({species}{n_str}, {nu=}, {l=}, {s_tot=}, {j_tot=}, {f_tot=}, {m=})"
400420

401421

402422
class RydbergStateSQDTAlkalineJJ(RydbergStateSQDT):
@@ -434,29 +454,18 @@ def __init__(
434454
"""
435455
super().__init__(species=species, n=n, nu=nu, l_r=l, j_r=j_r, j_tot=j_tot, f_tot=f_tot, m=m)
436456

457+
def _set_qn_as_attributes(self) -> None:
437458
self.l = self.angular.l_r
438459
self.j_r = self.angular.j_r
439460
self.j_tot = self.angular.j_tot
440461
self.f_tot = self.angular.f_tot
441462
self.m = self.angular.m
442463

443464
def __repr__(self) -> str:
444-
species, n, l, j_r, j_tot, f_tot, m = self.species, self.n, self.l, self.j_r, self.j_tot, self.f_tot, self.m
445-
return f"{self.__class__.__name__}({species.name}, {n=}, {l=}, {j_r=}, {j_tot=}, {f_tot=}, {m=})"
446-
447-
@cached_property
448-
def nu(self) -> float:
449-
if self._nu is not None:
450-
return self._nu
451-
assert self.n is not None
452-
nus = [self.species.calc_nu(self.n, self.l, self.j_tot, s_tot=s_tot) for s_tot in [0, 1]]
453-
454-
if any(abs(nu - nus[0]) > 1e-10 for nu in nus[1:]):
455-
raise ValueError(
456-
"RydbergStateSQDTAlkalineJJ is intended for high-l states only, "
457-
"where the quantum defects are the same for singlet and triplet states."
458-
)
459-
return nus[0]
465+
species, n, nu = self.species.name, self.n, self.nu
466+
l, j_r, j_tot, f_tot, m = self.l, self.j_r, self.j_tot, self.f_tot, self.m
467+
n_str = f", {n=}" if n is not None else ""
468+
return f"{self.__class__.__name__}({species}{n_str}, {nu=}, {l=}, {j_r=}, {j_tot=}, {f_tot=}, {m=})"
460469

461470

462471
class RydbergStateSQDTAlkalineFJ(RydbergStateSQDT):
@@ -494,30 +503,17 @@ def __init__(
494503
"""
495504
super().__init__(species=species, n=n, nu=nu, l_r=l, j_r=j_r, f_c=f_c, f_tot=f_tot, m=m)
496505

506+
def _set_qn_as_attributes(self) -> None:
497507
self.l = self.angular.l_r
498508
self.j_r = self.angular.j_r
499509
self.f_c = self.angular.f_c
500510
self.f_tot = self.angular.f_tot
501511
self.m = self.angular.m
502512

503513
def __repr__(self) -> str:
504-
species, n, l, j_r, f_c, f_tot, m = self.species, self.n, self.l, self.j_r, self.f_c, self.f_tot, self.m
505-
return f"{self.__class__.__name__}({species.name}, {n=}, {l=}, {j_r=}, {f_c=}, {f_tot=}, {m=})"
506-
507-
@cached_property
508-
def nu(self) -> float:
509-
if self._nu is not None:
510-
return self._nu
511-
assert self.n is not None
512-
nus = [
513-
self.species.calc_nu(self.n, self.l, float(j_tot), s_tot=s_tot)
514-
for s_tot in [0, 1]
515-
for j_tot in np.arange(abs(self.j_r - 1 / 2), self.j_r + 1 / 2 + 1)
516-
]
517-
518-
if any(abs(nu - nus[0]) > 1e-10 for nu in nus[1:]):
519-
raise ValueError(
520-
"RydbergStateSQDTAlkalineFJ is intended for high-l states only, "
521-
"where the quantum defects are the same for singlet and triplet states."
522-
)
523-
return nus[0]
514+
species, n, nu = self.species.name, self.n, self.nu
515+
l, j_r, f_c, f_tot, m = self.l, self.j_r, self.f_c, self.f_tot, self.m
516+
l_c, j_c = self.angular.l_c, self.angular.j_c
517+
core_string = f", {l_c=}, {j_c=}" if l_c != 0 else ""
518+
n_str = f", {n=}" if n is not None else ""
519+
return f"{self.__class__.__name__}({species}{n_str}, {nu=}{core_string}, {l=}, {j_r=}, {f_c=}, {f_tot=}, {m=})"

0 commit comments

Comments
 (0)