Skip to content

Commit cf988ee

Browse files
speed up check is not set
1 parent 507ab83 commit cf988ee

5 files changed

Lines changed: 25 additions & 16 deletions

File tree

src/rydstate/angular/angular_ket.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
NotSet,
1818
check_spin_addition_rule,
1919
get_possible_quantum_number_values,
20+
is_not_set,
2021
minus_one_pow,
2122
try_trivial_spin_addition,
2223
)
@@ -115,7 +116,7 @@ def __init__(
115116
self.l_r = int(l_r)
116117

117118
# f_tot will be set in the subclasses
118-
self.m = NotSet if isinstance(m, NotSet) else float(m)
119+
self.m = NotSet if is_not_set(m) else float(m)
119120

120121
def _post_init(self) -> None:
121122
self.quantum_numbers = tuple(getattr(self, qn) for qn in self.quantum_number_names)
@@ -133,7 +134,7 @@ def sanity_check(self, msgs: list[str] | None = None) -> None:
133134
if self.s_r != 0.5:
134135
msgs.append(f"Rydberg electron spin s_r must be 1/2, but {self.s_r=}")
135136

136-
if not isinstance(self.m, NotSet) and not -self.f_tot <= self.m <= self.f_tot:
137+
if not is_not_set(self.m) and not -self.f_tot <= self.m <= self.f_tot:
137138
msgs.append(f"m must be between -f_tot and f_tot, but {self.f_tot=}, {self.m=}")
138139

139140
if msgs:
@@ -150,7 +151,7 @@ def __setattr__(self, key: str, value: object) -> None:
150151

151152
def __repr__(self) -> str:
152153
args = ", ".join(f"{qn}={val}" for qn, val in zip(self.quantum_number_names, self.quantum_numbers, strict=True))
153-
if not isinstance(self.m, NotSet):
154+
if not is_not_set(self.m):
154155
args += f", m={self.m}"
155156
return f"{self.__class__.__name__}({args})"
156157

@@ -467,16 +468,16 @@ def calc_matrix_element(self, other: AngularKetBase, operator: AngularOperatorTy
467468
The dimensionless angular matrix element.
468469
469470
"""
470-
if isinstance(self.m, NotSet) or isinstance(other.m, NotSet):
471-
raise RuntimeError("m must be set to calculate the matrix element.") # noqa: TRY004
471+
if is_not_set(self.m) or is_not_set(other.m):
472+
raise RuntimeError("m must be set to calculate the matrix element.")
472473

473474
prefactor = self._calc_wigner_eckart_prefactor(other, kappa, q)
474475
reduced_matrix_element = self.calc_reduced_matrix_element(other, operator, kappa)
475476
return prefactor * reduced_matrix_element
476477

477478
def _calc_wigner_eckart_prefactor(self, other: AngularKetBase, kappa: int, q: int) -> float:
478-
if isinstance(self.m, NotSet) or isinstance(other.m, NotSet):
479-
raise RuntimeError("m must be set to calculate the Wigner-Eckart prefactor.") # noqa: TRY004
479+
if is_not_set(self.m) or is_not_set(other.m):
480+
raise RuntimeError("m must be set to calculate the Wigner-Eckart prefactor.")
480481
return minus_one_pow(self.f_tot - self.m) * calc_wigner_3j(self.f_tot, kappa, other.f_tot, -self.m, q, other.m)
481482

482483
def _kronecker_delta_non_involved_spins(self, other: AngularKetBase, qn: AngularMomentumQuantumNumbers) -> int:

src/rydstate/angular/angular_state.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
AngularKetLS,
1414
)
1515
from rydstate.angular.angular_matrix_element import is_angular_momentum_quantum_number
16-
from rydstate.angular.utils import NotSet
16+
from rydstate.angular.utils import is_not_set
1717

1818
if TYPE_CHECKING:
1919
from collections.abc import Iterator, Sequence
@@ -210,8 +210,8 @@ def calc_matrix_element(
210210
"Different m values are not supported yet for AngularState.calc_matrix_element."
211211
)
212212

213-
if isinstance(self.kets[0].m, NotSet) or isinstance(other.kets[0].m, NotSet):
214-
raise RuntimeError("m must be set for all kets to calculate the matrix element.") # noqa: TRY004
213+
if is_not_set(self.kets[0].m) or is_not_set(other.kets[0].m):
214+
raise RuntimeError("m must be set for all kets to calculate the matrix element.")
215215

216216
prefactor = self.kets[0]._calc_wigner_eckart_prefactor(other.kets[0], kappa, q) # noqa: SLF001
217217
reduced_matrix_element = self.calc_reduced_matrix_element(other, operator, kappa)

src/rydstate/angular/utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22

33
import contextlib
44
import typing as t
5-
from typing import TYPE_CHECKING, Literal
5+
from typing import TYPE_CHECKING, Any, Literal
66

77
import numpy as np
88

99
if TYPE_CHECKING:
10+
from typing_extensions import TypeIs
11+
1012
from rydstate.angular.angular_ket import AngularKetBase
1113
from rydstate.species.species_object import SpeciesObject
1214

@@ -26,6 +28,11 @@ class NotSet(t.Protocol):
2628
def __not_set() -> None: ...
2729

2830

31+
def is_not_set(obj: Any) -> TypeIs[NotSet]: # noqa: ANN401
32+
"""Check if the obj is the NotSet singleton."""
33+
return obj is NotSet
34+
35+
2936
class InvalidQuantumNumbersError(ValueError):
3037
def __init__(self, ket: AngularKetBase, msg: str = "") -> None:
3138
_msg = f"Invalid quantum numbers for {ket!r}"

src/rydstate/basis/basis_sqdt.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
import numpy as np
77

8-
from rydstate.angular.utils import NotSet
8+
from rydstate.angular import NotSet
9+
from rydstate.angular.utils import is_not_set
910
from rydstate.basis.basis_base import BasisBase
1011
from rydstate.rydberg import (
1112
RydbergStateSQDT,
@@ -23,7 +24,7 @@
2324

2425
class BasisSQDT(BasisBase[_RydbergStateSQDT]):
2526
def _get_m_range(self, m: tuple[float, float] | None | NotSet, f_tot: float | np.floating) -> list[NotSet | float]:
26-
if isinstance(m, NotSet):
27+
if is_not_set(m):
2728
return [NotSet]
2829
if m is None:
2930
m = (-np.inf, np.inf)

src/rydstate/rydberg/rydberg_sqdt.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from scipy.special import exprel
1010

1111
from rydstate.angular import NotSet
12-
from rydstate.angular.utils import quantum_numbers_to_angular_ket
12+
from rydstate.angular.utils import is_not_set, quantum_numbers_to_angular_ket
1313
from rydstate.radial import RadialKet
1414
from rydstate.rydberg.rydberg_base import RydbergStateBase
1515
from rydstate.species import SpeciesObjectSQDT
@@ -428,8 +428,8 @@ def _get_transition_rates_au(
428428
basis_class = BasisSQDTAlkali if self.species.number_valence_electrons == 1 else BasisSQDTAlkalineLS
429429

430430
m = self.angular.m
431-
if isinstance(m, NotSet):
432-
raise RuntimeError("m quantum number must be defined to calculate transition rates.") # noqa: TRY004
431+
if is_not_set(m):
432+
raise RuntimeError("m quantum number must be defined to calculate transition rates.")
433433

434434
basis = basis_class(self.species, n=(1, int(self.nu + 35)), m=(m - 1, m + 1))
435435
basis.filter_states("l_r", (self.angular.l_r - 1, self.angular.l_r + 1))

0 commit comments

Comments
 (0)