Skip to content
This repository was archived by the owner on Nov 10, 2025. It is now read-only.

Commit f541672

Browse files
improve angular ket
1 parent aeb8c3a commit f541672

3 files changed

Lines changed: 30 additions & 31 deletions

File tree

src/ryd_numerov/angular/angular_ket.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,14 @@ class AngularKetBase(ABC):
3737
"""Base class for a angular ket (i.e. a simple canonical spin ketstate)."""
3838

3939
# We use __slots__ to prevent dynamic attributes and make the objects immutable after initialization
40-
__slots__ = ("i_c", "s_c", "l_c", "s_r", "l_r", "f_tot", "m", "_initialized")
40+
__slots__ = ("i_c", "s_c", "l_c", "s_r", "l_r", "f_tot", "m", "quantum_numbers", "_initialized")
4141

42-
spin_quantum_number_names: ClassVar[set[AngularMomentumQuantumNumbers]]
42+
quantum_number_names: ClassVar[tuple[AngularMomentumQuantumNumbers, ...]]
4343
"""Names of all well defined spin quantum numbers (without the magnetic quantum number m) in this class."""
4444

45+
quantum_numbers: tuple[float, ...]
46+
"""The quantum numbers corresponding to the quantum_number_names (without the magnetic quantum number m)."""
47+
4548
coupled_quantum_numbers: ClassVar[
4649
dict[AngularMomentumQuantumNumbers, tuple[AngularMomentumQuantumNumbers, AngularMomentumQuantumNumbers]]
4750
]
@@ -109,6 +112,8 @@ def __init__(
109112
self.m = None if m is None else float(m)
110113

111114
def _post_init(self) -> None:
115+
self.quantum_numbers = tuple(getattr(self, qn) for qn in self.quantum_number_names)
116+
112117
self._initialized = True
113118

114119
self.sanity_check()
@@ -138,7 +143,7 @@ def __setattr__(self, key: str, value: object) -> None:
138143
super().__setattr__(key, value)
139144

140145
def __repr__(self) -> str:
141-
args = ", ".join(f"{k}={v}" for k, v in self.spin_quantum_numbers_dict.items())
146+
args = ", ".join(f"{qn}={val}" for qn, val in zip(self.quantum_number_names, self.quantum_numbers))
142147
if self.m is not None:
143148
args += f", m={self.m}"
144149
return f"{self.__class__.__name__}({args})"
@@ -153,24 +158,20 @@ def __eq__(self, other: object) -> bool:
153158
return False
154159
if self.m != other.m:
155160
return False
156-
return all(self.get_qn(q) == other.get_qn(q) for q in self.spin_quantum_numbers_dict)
161+
return self.quantum_numbers == other.quantum_numbers
157162

158163
def __hash__(self) -> int:
159164
return hash(
160165
(
161-
tuple((k, v) for k, v in self.spin_quantum_numbers_dict.items()),
166+
self.quantum_number_names,
167+
self.quantum_numbers,
162168
self.m,
163169
)
164170
)
165171

166-
@property
167-
def spin_quantum_numbers_dict(self) -> dict[AngularMomentumQuantumNumbers, float | int]:
168-
"""Return the spin quantum numbers (i.e. without the magnetic quantum number) as dictionary."""
169-
return {q: getattr(self, q) for q in self.spin_quantum_number_names}
170-
171172
def get_qn(self, qn: AngularMomentumQuantumNumbers) -> float:
172173
"""Get the value of a quantum number by name."""
173-
if qn not in self.spin_quantum_number_names:
174+
if qn not in self.quantum_number_names:
174175
raise ValueError(f"Quantum number {qn} not found in {self!r}.")
175176
return getattr(self, qn) # type: ignore [no-any-return]
176177

@@ -205,7 +206,7 @@ def calc_reduced_overlap(self, other: AngularKetBase) -> float:
205206
If the kets are of different types, the overlap is calculated using the corresponding
206207
Clebsch-Gordan coefficients (/ Wigner-j symbols).
207208
"""
208-
for q in self.spin_quantum_number_names & other.spin_quantum_number_names:
209+
for q in set(self.quantum_number_names) & set(other.quantum_number_names):
209210
if self.get_qn(q) != other.get_qn(q):
210211
return 0
211212

@@ -258,14 +259,14 @@ def calc_reduced_matrix_element( # noqa: C901
258259

259260
if type(self) is not type(other):
260261
return self.to_state().calc_reduced_matrix_element(other.to_state(), operator, kappa)
261-
if operator in get_args(AngularMomentumQuantumNumbers) and operator not in self.spin_quantum_number_names:
262+
if operator in get_args(AngularMomentumQuantumNumbers) and operator not in self.quantum_number_names:
262263
return self.to_state().calc_reduced_matrix_element(other.to_state(), operator, kappa)
263264

264265
qn_name: AngularMomentumQuantumNumbers
265266
if operator == "SPHERICAL":
266267
qn_name = "l_r"
267268
complete_reduced_matrix_element = calc_reduced_spherical_matrix_element(self.l_r, other.l_r, kappa)
268-
elif operator in self.spin_quantum_number_names:
269+
elif operator in self.quantum_number_names:
269270
if not kappa == 1:
270271
raise ValueError("Only kappa=1 is supported for spin operators.")
271272
qn_name = operator # type: ignore [assignment]
@@ -334,7 +335,7 @@ def _kronecker_delta_non_involved_spins(self, other: AngularKetBase, qn: Angular
334335
This means return 0 if any of the quantum numbers,
335336
that are not qn or a coupled quantum number resulting from qn differ between self and other.
336337
"""
337-
if qn not in self.spin_quantum_number_names:
338+
if qn not in self.quantum_number_names:
338339
raise ValueError(f"Quantum number {qn} is not a valid angular momentum quantum number for {self!r}.")
339340

340341
resulting_qns = {qn}
@@ -350,7 +351,7 @@ def _kronecker_delta_non_involved_spins(self, other: AngularKetBase, qn: Angular
350351
f"_kronecker_delta_non_involved_spins: {last_qn} not found in coupled_quantum_numbers."
351352
)
352353

353-
non_involved_qns = self.spin_quantum_number_names - resulting_qns
354+
non_involved_qns = set(self.quantum_number_names) - resulting_qns
354355
for _qn in non_involved_qns:
355356
if self.get_qn(_qn) != other.get_qn(_qn):
356357
return 0
@@ -396,7 +397,7 @@ class AngularKetLS(AngularKetBase):
396397
"""Spin ket in LS coupling."""
397398

398399
__slots__ = ("s_tot", "l_tot", "j_tot")
399-
spin_quantum_number_names: ClassVar = {"i_c", "s_c", "l_c", "s_r", "l_r", "s_tot", "l_tot", "j_tot", "f_tot"}
400+
quantum_number_names: ClassVar = ("i_c", "s_c", "l_c", "s_r", "l_r", "s_tot", "l_tot", "j_tot", "f_tot")
400401
coupled_quantum_numbers: ClassVar = {
401402
"s_tot": ("s_c", "s_r"),
402403
"l_tot": ("l_c", "l_r"),
@@ -518,7 +519,7 @@ class AngularKetJJ(AngularKetBase):
518519
"""Spin ket in JJ coupling."""
519520

520521
__slots__ = ("j_c", "j_r", "j_tot")
521-
spin_quantum_number_names: ClassVar = {"i_c", "s_c", "l_c", "s_r", "l_r", "j_c", "j_r", "j_tot", "f_tot"}
522+
quantum_number_names: ClassVar = ("i_c", "s_c", "l_c", "s_r", "l_r", "j_c", "j_r", "j_tot", "f_tot")
522523
coupled_quantum_numbers: ClassVar = {
523524
"j_c": ("s_c", "l_c"),
524525
"j_r": ("s_r", "l_r"),
@@ -652,7 +653,7 @@ class AngularKetFJ(AngularKetBase):
652653
"""Spin ket in FJ coupling."""
653654

654655
__slots__ = ("j_c", "f_c", "j_r")
655-
spin_quantum_number_names: ClassVar = {"i_c", "s_c", "l_c", "s_r", "l_r", "j_c", "f_c", "j_r", "f_tot"}
656+
quantum_number_names: ClassVar = ("i_c", "s_c", "l_c", "s_r", "l_r", "j_c", "f_c", "j_r", "f_tot")
656657
coupled_quantum_numbers: ClassVar = {
657658
"j_c": ("s_c", "l_c"),
658659
"f_c": ("i_c", "j_c"),

src/ryd_numerov/angular/angular_state.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,9 @@ def calc_exp_qn(self, q: AngularMomentumQuantumNumbers) -> float:
9797
q: The quantum number to calculate the expectation value for.
9898
9999
"""
100-
if q not in self.kets[0].spin_quantum_number_names:
100+
if q not in self.kets[0].quantum_number_names:
101101
for ket_class in [AngularKetLS, AngularKetJJ, AngularKetFJ]:
102-
if q in ket_class.spin_quantum_number_names:
102+
if q in ket_class.quantum_number_names:
103103
return self._to_coupling_scheme(ket_class.coupling_scheme).calc_exp_qn(q)
104104

105105
qs = np.array([ket.get_qn(q) for ket in self.kets])
@@ -115,9 +115,9 @@ def calc_std_qn(self, q: AngularMomentumQuantumNumbers) -> float:
115115
q: The quantum number to calculate the standard deviation for.
116116
117117
"""
118-
if q not in self.kets[0].spin_quantum_number_names:
118+
if q not in self.kets[0].quantum_number_names:
119119
for ket_class in [AngularKetLS, AngularKetJJ, AngularKetFJ]:
120-
if q in ket_class.spin_quantum_number_names:
120+
if q in ket_class.quantum_number_names:
121121
return self._to_coupling_scheme(ket_class.coupling_scheme).calc_std_qn(q)
122122

123123
qs = np.array([ket.get_qn(q) for ket in self.kets])
@@ -156,12 +156,9 @@ def calc_reduced_matrix_element(
156156
"""
157157
if isinstance(other, AngularKetBase):
158158
other = other.to_state()
159-
if (
160-
operator in get_args(AngularMomentumQuantumNumbers)
161-
and operator not in self.kets[0].spin_quantum_number_names
162-
):
159+
if operator in get_args(AngularMomentumQuantumNumbers) and operator not in self.kets[0].quantum_number_names:
163160
for ket_class in [AngularKetLS, AngularKetJJ, AngularKetFJ]:
164-
if operator in ket_class.spin_quantum_number_names:
161+
if operator in ket_class.quantum_number_names:
165162
return self._to_coupling_scheme(ket_class.coupling_scheme).calc_reduced_matrix_element(
166163
other, operator, kappa
167164
)

tests/test_angular_matrix_elements.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,14 @@ def test_reduced_identity(ket: AngularKetBase) -> None:
7979
state_jj = ket.to_jj()
8080
state_fj = ket.to_fj()
8181

82-
for op in state_ls.kets[0].spin_quantum_number_names:
82+
op: AngularMomentumQuantumNumbers
83+
for op in state_ls.kets[0].quantum_number_names:
8384
assert np.isclose(reduced_identity, state_ls.calc_reduced_matrix_element(state_ls, "identity_" + op, kappa=0)) # type: ignore [arg-type]
8485

85-
for op in state_jj.kets[0].spin_quantum_number_names:
86+
for op in state_jj.kets[0].quantum_number_names:
8687
assert np.isclose(reduced_identity, state_jj.calc_reduced_matrix_element(state_jj, "identity_" + op, kappa=0)) # type: ignore [arg-type]
8788

88-
for op in state_fj.kets[0].spin_quantum_number_names:
89+
for op in state_fj.kets[0].quantum_number_names:
8990
assert np.isclose(reduced_identity, state_fj.calc_reduced_matrix_element(state_fj, "identity_" + op, kappa=0)) # type: ignore [arg-type]
9091

9192

0 commit comments

Comments
 (0)