Skip to content
This repository was archived by the owner on Nov 10, 2025. It is now read-only.

Commit cdd376d

Browse files
improve angular utils and angular init
1 parent c1981bc commit cdd376d

6 files changed

Lines changed: 104 additions & 115 deletions

File tree

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,9 @@
1-
from ryd_numerov.angular.angular_ket import AngularKetBase, AngularKetFJ, AngularKetJJ, AngularKetLS
2-
from ryd_numerov.angular.angular_matrix_element import (
3-
calc_prefactor_of_operator_in_coupled_scheme,
4-
calc_reduced_spherical_matrix_element,
5-
calc_reduced_spin_matrix_element,
6-
)
1+
from ryd_numerov.angular.angular_ket import AngularKetFJ, AngularKetJJ, AngularKetLS
72
from ryd_numerov.angular.angular_state import AngularState
8-
from ryd_numerov.angular.utils import (
9-
calc_wigner_3j,
10-
calc_wigner_6j,
11-
calc_wigner_9j,
12-
clebsch_gordan_6j,
13-
clebsch_gordan_9j,
14-
)
153

164
__all__ = [
17-
"AngularKetBase",
185
"AngularKetFJ",
196
"AngularKetJJ",
207
"AngularKetLS",
218
"AngularState",
22-
"calc_prefactor_of_operator_in_coupled_scheme",
23-
"calc_reduced_spherical_matrix_element",
24-
"calc_reduced_spin_matrix_element",
25-
"calc_wigner_3j",
26-
"calc_wigner_6j",
27-
"calc_wigner_9j",
28-
"clebsch_gordan_6j",
29-
"clebsch_gordan_9j",
309
]

src/ryd_numerov/angular/angular_ket.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
clebsch_gordan_6j,
1919
clebsch_gordan_9j,
2020
get_possible_quantum_number_list,
21+
minus_one_pow,
2122
try_trivial_spin_addition,
2223
)
2324
from ryd_numerov.elements import BaseElement
@@ -449,7 +450,7 @@ def calc_matrix_element(self, other: AngularKetBase, operator: AngularOperatorTy
449450

450451
def _calc_wigner_eckart_prefactor(self, other: AngularKetBase, kappa: int, q: int) -> float:
451452
assert self.m is not None and other.m is not None, "m must be set to calculate the Wigner-Eckart prefactor." # noqa: PT018
452-
return (-1) ** (self.f_tot - self.m) * calc_wigner_3j(self.f_tot, kappa, other.f_tot, -self.m, q, other.m) # type: ignore [return-value]
453+
return minus_one_pow(self.f_tot - self.m) * calc_wigner_3j(self.f_tot, kappa, other.f_tot, -self.m, q, other.m)
453454

454455
def _kronecker_delta_non_involved_spins(self, other: AngularKetBase, qn: AngularMomentumQuantumNumbers) -> int:
455456
"""Calculate the Kronecker delta for non involved angular momentum quantum numbers.

src/ryd_numerov/angular/angular_matrix_element.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ def calc_reduced_spherical_matrix_element(l_r_final: int, l_r_initial: int, kapp
7676
def calc_reduced_spin_matrix_element(s_final: float, s_initial: float) -> float:
7777
r"""Calculate the reduced spin matrix element (s_final || \hat{s} || s_initial).
7878
79-
The spin operator \hat{s} must be the operator corresponding to the quantum number s_final and s_initial.
79+
The spin operator \hat{s} can be any of the AngularMomentumQuantumNumbers,
80+
but must be corresponding to the given quantum number s_final and s_initial.
8081
8182
We follow the convention of equation (5.4.3) from Edmonds 1985 "Angular Momentum in Quantum Mechanics".
8283
The matrix elements of the spin operators are given by:

src/ryd_numerov/angular/utils.py

Lines changed: 97 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,16 @@
2121
def lru_cache(maxsize: int) -> Callable[[Callable[P, R]], Callable[P, R]]: ... # type: ignore [no-redef]
2222

2323

24+
# global variables to possibly improve the performance of wigner j calculations
25+
# in the public release we will always use CHECK_ARGS = True and USE_SYMMETRIES = False to reduce potential of bugs
26+
CHECK_ARGS = True
2427
USE_SYMMETRIES = False
2528

2629

2730
def sympify_args(func: Callable[P, R]) -> Callable[P, R]:
2831
"""Check that quantum numbers are valid and convert to sympy.Integer (and half-integer)."""
32+
if not CHECK_ARGS:
33+
return func
2934

3035
def check_arg(arg: float) -> Integer:
3136
if arg.is_integer():
@@ -43,8 +48,86 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
4348
return wrapper
4449

4550

51+
@lru_cache(maxsize=10_000)
52+
@sympify_args
4653
def calc_wigner_3j(j1: float, j2: float, j3: float, m1: float, m2: float, m3: float) -> float:
47-
"""Calculate the Wigner 3j symbol using symmetries and lru_cache to improve performance."""
54+
"""Calculate the Wigner 3j symbol using lru_cache to improve performance."""
55+
return float(sympy_wigner_3j(j1, j2, j3, m1, m2, m3).evalf())
56+
57+
58+
@lru_cache(maxsize=100_000)
59+
@sympify_args
60+
def calc_wigner_6j(j1: float, j2: float, j3: float, j4: float, j5: float, j6: float) -> float:
61+
"""Calculate the Wigner 6j symbol using lru_cache to improve performance."""
62+
return float(sympy_wigner_6j(j1, j2, j3, j4, j5, j6).evalf())
63+
64+
65+
@lru_cache(maxsize=10_000)
66+
@sympify_args
67+
def calc_wigner_9j(
68+
j1: float, j2: float, j3: float, j4: float, j5: float, j6: float, j7: float, j8: float, j9: float
69+
) -> float:
70+
"""Calculate the Wigner 9j symbol using lru_cache to improve performance."""
71+
return float(sympy_wigner_9j(j1, j2, j3, j4, j5, j6, j7, j8, j9).evalf())
72+
73+
74+
def clebsch_gordan_6j(j1: float, j2: float, j3: float, j12: float, j23: float, j_tot: float) -> float:
75+
"""Calculate the overlap between <((j1,j2)j12,j3)j_tot|(j1,(j2,j3)j23)j_tot>.
76+
77+
We follow the convention of equation (6.1.5) from Edmonds 1985 "Angular Momentum in Quantum Mechanics".
78+
79+
See Also:
80+
- https://en.wikipedia.org/wiki/Racah_W-coefficient
81+
- https://en.wikipedia.org/wiki/6-j_symbol
82+
83+
Args:
84+
j1: Spin quantum number 1.
85+
j2: Spin quantum number 2.
86+
j3: Spin quantum number 3.
87+
j12: Total spin quantum number of j1 + j2.
88+
j23: Total spin quantum number of j2 + j3.
89+
j_tot: Total spin quantum number of j1 + j2 + j3.
90+
91+
Returns:
92+
The Clebsch-Gordan coefficient <((j1,j2)j12,j3)j_tot|(j1,(j2,j3)j23)j_tot>.
93+
94+
"""
95+
prefactor = minus_one_pow(j1 + j2 + j3 + j_tot) * math.sqrt((2 * j12 + 1) * (2 * j23 + 1))
96+
wigner_6j = calc_wigner_6j(j1, j2, j12, j3, j_tot, j23)
97+
return prefactor * wigner_6j
98+
99+
100+
def clebsch_gordan_9j(
101+
j1: float, j2: float, j12: float, j3: float, j4: float, j34: float, j13: float, j24: float, j_tot: float
102+
) -> float:
103+
"""Calculate the overlap between <((j1,j2)j12,(j3,j4)j34))j_tot|((j1,j3)j13,(j2,j4)j24))j_tot>.
104+
105+
We follow the convention of equation (6.4.2) from Edmonds 1985 "Angular Momentum in Quantum Mechanics".
106+
107+
See Also:
108+
- https://en.wikipedia.org/wiki/9-j_symbol
109+
110+
Args:
111+
j1: Spin quantum number 1.
112+
j2: Spin quantum number 2.
113+
j12: Total spin quantum number of j1 + j2.
114+
j3: Spin quantum number 1.
115+
j4: Spin quantum number 2.
116+
j34: Total spin quantum number of j1 + j2.
117+
j13: Total spin quantum number of j1 + j3.
118+
j24: Total spin quantum number of j2 + j4.
119+
j_tot: Total spin quantum number of j1 + j2 + j3 + j4.
120+
121+
Returns:
122+
The Clebsch-Gordan coefficient <((j1,j2)j12,(j3,j4)j34))j_tot|((j1,j3)j13,(j2,j4)j24))j_tot>.
123+
124+
"""
125+
prefactor = math.sqrt((2 * j12 + 1) * (2 * j34 + 1) * (2 * j13 + 1) * (2 * j24 + 1))
126+
return prefactor * calc_wigner_9j(j1, j2, j12, j3, j4, j34, j13, j24, j_tot)
127+
128+
129+
def calc_wigner_3j_with_symmetries(j1: float, j2: float, j3: float, m1: float, m2: float, m3: float) -> float:
130+
"""Calculate the Wigner 3j symbol using symmetries to reduce the number of symbols, that are not cached."""
48131
symmetry_factor: float = 1
49132

50133
# even permutation -> sort smallest j to be j1
@@ -65,17 +148,11 @@ def calc_wigner_3j(j1: float, j2: float, j3: float, m1: float, m2: float, m3: fl
65148

66149
# TODO Regge symmetries
67150

68-
return symmetry_factor * _calc_wigner_3j(j1, j2, j3, m1, m2, m3)
151+
return symmetry_factor * calc_wigner_3j(j1, j2, j3, m1, m2, m3)
69152

70153

71-
@lru_cache(maxsize=10_000)
72-
@sympify_args
73-
def _calc_wigner_3j(j1: float, j2: float, j3: float, m1: float, m2: float, m3: float) -> float:
74-
return float(sympy_wigner_3j(j1, j2, j3, m1, m2, m3).evalf())
75-
76-
77-
def calc_wigner_6j(j1: float, j2: float, j3: float, j4: float, j5: float, j6: float) -> float:
78-
"""Calculate the Wigner 6j symbol using symmetries and lru_cache to improve performance."""
154+
def calc_wigner_6j_with_symmetries(j1: float, j2: float, j3: float, j4: float, j5: float, j6: float) -> float:
155+
"""Calculate the Wigner 6j symbol using symmetries to reduce the number of symbols, that are not cached."""
79156
# interchange upper and lower for 2 columns -> make j1 < j4 and j2 < j5
80157
if j4 < j1:
81158
j1, j2, j3, j4, j5, j6 = j4, j2, j6, j1, j5, j3 # noqa: PLW0127
@@ -91,19 +168,13 @@ def calc_wigner_6j(j1: float, j2: float, j3: float, j4: float, j5: float, j6: fl
91168
if j3 < j2:
92169
j1, j2, j3, j4, j5, j6 = j1, j3, j2, j4, j6, j5 # noqa: PLW0127
93170

94-
return _calc_wigner_6j(j1, j2, j3, j4, j5, j6)
95-
171+
return calc_wigner_6j(j1, j2, j3, j4, j5, j6)
96172

97-
@lru_cache(maxsize=100_000)
98-
@sympify_args
99-
def _calc_wigner_6j(j1: float, j2: float, j3: float, j4: float, j5: float, j6: float) -> float:
100-
return float(sympy_wigner_6j(j1, j2, j3, j4, j5, j6).evalf())
101173

102-
103-
def calc_wigner_9j(
174+
def calc_wigner_9j_with_symmetries(
104175
j1: float, j2: float, j3: float, j4: float, j5: float, j6: float, j7: float, j8: float, j9: float
105176
) -> float:
106-
"""Calculate the Wigner 9j symbol using symmetries and lru_cache to improve performance."""
177+
"""Calculate the Wigner 9j symbol using symmetries to reduce the number of symbols, that are not cached."""
107178
symmetry_factor: float = 1
108179
js = [j1, j2, j3, j4, j5, j6, j7, j8, j9]
109180

@@ -132,78 +203,22 @@ def calc_wigner_9j(
132203
if js[3] < js[1]:
133204
js = [js[0], js[3], js[6], js[1], js[4], js[7], js[2], js[5], js[8]]
134205

135-
return symmetry_factor * _calc_wigner_9j(*js)
136-
137-
138-
@lru_cache(maxsize=10_000)
139-
@sympify_args
140-
def _calc_wigner_9j(
141-
j1: float, j2: float, j3: float, j4: float, j5: float, j6: float, j7: float, j8: float, j9: float
142-
) -> float:
143-
return float(sympy_wigner_9j(j1, j2, j3, j4, j5, j6, j7, j8, j9).evalf())
144-
145-
146-
def clebsch_gordan_6j(j1: float, j2: float, j3: float, j12: float, j23: float, j_tot: float) -> float:
147-
"""Calculate the overlap between <((j1,j2)j12,j3)j_tot|(j1,(j2,j3)j23)j_tot>.
148-
149-
We follow the convention of equation (6.1.5) from Edmonds 1985 "Angular Momentum in Quantum Mechanics".
150-
151-
See Also:
152-
- https://en.wikipedia.org/wiki/Racah_W-coefficient
153-
- https://en.wikipedia.org/wiki/6-j_symbol
154-
155-
Args:
156-
j1: Spin quantum number 1.
157-
j2: Spin quantum number 2.
158-
j3: Spin quantum number 3.
159-
j12: Total spin quantum number of j1 + j2.
160-
j23: Total spin quantum number of j2 + j3.
161-
j_tot: Total spin quantum number of j1 + j2 + j3.
162-
163-
Returns:
164-
The Clebsch-Gordan coefficient <((j1,j2)j12,j3)j_tot|(j1,(j2,j3)j23)j_tot>.
206+
return symmetry_factor * calc_wigner_9j(*js)
165207

166-
"""
167-
prefactor = minus_one_pow(j1 + j2 + j3 + j_tot) * math.sqrt((2 * j12 + 1) * (2 * j23 + 1))
168-
wigner_6j = calc_wigner_6j(j1, j2, j12, j3, j_tot, j23)
169-
return prefactor * wigner_6j
170208

171-
172-
def clebsch_gordan_9j(
173-
j1: float, j2: float, j12: float, j3: float, j4: float, j34: float, j13: float, j24: float, j_tot: float
174-
) -> float:
175-
"""Calculate the overlap between <((j1,j2)j12,(j3,j4)j34))j_tot|((j1,j3)j13,(j2,j4)j24))j_tot>.
176-
177-
We follow the convention of equation (6.4.2) from Edmonds 1985 "Angular Momentum in Quantum Mechanics".
178-
179-
See Also:
180-
- https://en.wikipedia.org/wiki/9-j_symbol
181-
182-
Args:
183-
j1: Spin quantum number 1.
184-
j2: Spin quantum number 2.
185-
j12: Total spin quantum number of j1 + j2.
186-
j3: Spin quantum number 1.
187-
j4: Spin quantum number 2.
188-
j34: Total spin quantum number of j1 + j2.
189-
j13: Total spin quantum number of j1 + j3.
190-
j24: Total spin quantum number of j2 + j4.
191-
j_tot: Total spin quantum number of j1 + j2 + j3 + j4.
192-
193-
Returns:
194-
The Clebsch-Gordan coefficient <((j1,j2)j12,(j3,j4)j34))j_tot|((j1,j3)j13,(j2,j4)j24))j_tot>.
195-
196-
"""
197-
prefactor = math.sqrt((2 * j12 + 1) * (2 * j34 + 1) * (2 * j13 + 1) * (2 * j24 + 1))
198-
return prefactor * calc_wigner_9j(j1, j2, j12, j3, j4, j34, j13, j24, j_tot)
209+
if USE_SYMMETRIES:
210+
calc_wigner_3j = calc_wigner_3j_with_symmetries # type: ignore [assignment]
211+
calc_wigner_6j = calc_wigner_6j_with_symmetries # type: ignore [assignment]
212+
calc_wigner_9j = calc_wigner_9j_with_symmetries # type: ignore [assignment]
199213

200214

201215
def minus_one_pow(n: float) -> int:
216+
"""Calculate (-1)^n for an integer n and raise an error if n is not an integer."""
202217
if n % 2 == 0:
203218
return 1
204219
if n % 2 == 1:
205220
return -1
206-
raise ValueError(f"Invalid input {n}.")
221+
raise ValueError(f"minus_one_pow: Invalid input {n=} is not an integer.")
207222

208223

209224
def try_trivial_spin_addition(s_1: float, s_2: float, s_tot: float | None, name: str) -> float:
@@ -235,9 +250,3 @@ def get_possible_quantum_number_list(s_1: float, s_2: float, s_tot: float | None
235250
if s_tot is not None:
236251
return [s_tot]
237252
return [float(s) for s in np.arange(abs(s_1 - s_2), s_1 + s_2 + 1, 1)]
238-
239-
240-
if not USE_SYMMETRIES:
241-
calc_wigner_3j = _calc_wigner_3j
242-
calc_wigner_6j = _calc_wigner_6j
243-
calc_wigner_9j = _calc_wigner_9j

src/ryd_numerov/rydberg_state.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
if TYPE_CHECKING:
1818
from typing_extensions import Self
1919

20-
from ryd_numerov.angular import AngularKetBase
20+
from ryd_numerov.angular.angular_ket import AngularKetBase
2121
from ryd_numerov.units import PintFloat
2222

2323

tests/test_angular_matrix_elements.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
from ryd_numerov.angular.angular_matrix_element import AngularMomentumQuantumNumbers
99

1010
if TYPE_CHECKING:
11-
from ryd_numerov.angular import AngularKetBase
12-
from ryd_numerov.angular.angular_ket import CouplingScheme
11+
from ryd_numerov.angular.angular_ket import AngularKetBase, CouplingScheme
1312
from ryd_numerov.angular.angular_matrix_element import AngularOperatorType
1413

1514
TEST_KET_PAIRS = [

0 commit comments

Comments
 (0)