Skip to content

Commit 10e59a1

Browse files
split into wigner symbols and utils
1 parent 18f1362 commit 10e59a1

4 files changed

Lines changed: 216 additions & 213 deletions

File tree

src/rydstate/angular/angular_ket.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,12 @@
1414
is_angular_operator_type,
1515
)
1616
from rydstate.angular.utils import (
17-
calc_wigner_3j,
1817
check_spin_addition_rule,
19-
clebsch_gordan_6j,
20-
clebsch_gordan_9j,
2118
get_possible_quantum_number_values,
2219
minus_one_pow,
2320
try_trivial_spin_addition,
2421
)
22+
from rydstate.angular.wigner_symbols import calc_wigner_3j, clebsch_gordan_6j, clebsch_gordan_9j
2523
from rydstate.species import SpeciesObject
2624

2725
if TYPE_CHECKING:

src/rydstate/angular/angular_matrix_element.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
import numpy as np
88
from typing_extensions import TypeGuard
99

10-
from rydstate.angular.utils import calc_wigner_3j, calc_wigner_6j, minus_one_pow
10+
from rydstate.angular.utils import minus_one_pow
11+
from rydstate.angular.wigner_symbols import calc_wigner_3j, calc_wigner_6j
1112

1213
if TYPE_CHECKING:
1314
from typing_extensions import ParamSpec

src/rydstate/angular/utils.py

Lines changed: 0 additions & 209 deletions
Original file line numberDiff line numberDiff line change
@@ -1,215 +1,6 @@
11
from __future__ import annotations
22

3-
import math
4-
from functools import lru_cache, wraps
5-
from typing import TYPE_CHECKING, Callable, TypeVar
6-
73
import numpy as np
8-
from sympy import Integer
9-
from sympy.physics.wigner import (
10-
wigner_3j as sympy_wigner_3j,
11-
wigner_6j as sympy_wigner_6j,
12-
wigner_9j as sympy_wigner_9j,
13-
)
14-
15-
if TYPE_CHECKING:
16-
from typing_extensions import ParamSpec
17-
18-
P = ParamSpec("P")
19-
R = TypeVar("R")
20-
21-
def lru_cache(maxsize: int) -> Callable[[Callable[P, R]], Callable[P, R]]: ... # type: ignore [no-redef]
22-
23-
24-
# global variables to possibly improve the performance of wigner j calculations
25-
# in the public release we will always use CHECK_ARGS = True and USE_SYMMETRIES = False to reduce potential of bugs
26-
CHECK_ARGS = True
27-
USE_SYMMETRIES = False
28-
29-
30-
def sympify_args(func: Callable[P, R]) -> Callable[P, R]:
31-
"""Check that quantum numbers are valid and convert to sympy.Integer (and half-integer)."""
32-
if not CHECK_ARGS:
33-
return func
34-
35-
def check_arg(arg: float) -> Integer:
36-
if isinstance(arg, int) or arg.is_integer():
37-
return Integer(int(arg))
38-
if isinstance(arg * 2, int) or (arg * 2).is_integer():
39-
return Integer(int(arg * 2)) / Integer(2)
40-
raise ValueError(f"Invalid input to {func.__name__}: {arg}.")
41-
42-
@wraps(func)
43-
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
44-
_args = [check_arg(arg) for arg in args] # type: ignore[arg-type]
45-
_kwargs = {key: check_arg(value) for key, value in kwargs.items()} # type: ignore[arg-type]
46-
return func(*_args, **_kwargs)
47-
48-
return wrapper
49-
50-
51-
@lru_cache(maxsize=100_000)
52-
@sympify_args
53-
def calc_wigner_3j(j1: float, j2: float, j3: float, m1: float, m2: float, m3: float) -> float:
54-
"""Calculate the Wigner 3j symbol using lru_cache to improve performance."""
55-
return float(sympy_wigner_3j(j1, j2, j3, m1, m2, m3).evalf())
56-
57-
58-
@lru_cache(maxsize=100_000)
59-
@sympify_args
60-
def calc_wigner_6j(j1: float, j2: float, j3: float, j4: float, j5: float, j6: float) -> float:
61-
"""Calculate the Wigner 6j symbol using lru_cache to improve performance."""
62-
return float(sympy_wigner_6j(j1, j2, j3, j4, j5, j6).evalf())
63-
64-
65-
@lru_cache(maxsize=10_000)
66-
@sympify_args
67-
def calc_wigner_9j(
68-
j1: float, j2: float, j3: float, j4: float, j5: float, j6: float, j7: float, j8: float, j9: float
69-
) -> float:
70-
"""Calculate the Wigner 9j symbol using lru_cache to improve performance."""
71-
return float(sympy_wigner_9j(j1, j2, j3, j4, j5, j6, j7, j8, j9).evalf())
72-
73-
74-
def clebsch_gordan_6j(j1: float, j2: float, j3: float, j12: float, j23: float, j_tot: float) -> float:
75-
"""Calculate the overlap between <((j1,j2)j12,j3)j_tot|(j1,(j2,j3)j23)j_tot>.
76-
77-
We follow the convention of equation (6.1.5) from Edmonds 1985 "Angular Momentum in Quantum Mechanics".
78-
79-
See Also:
80-
- https://en.wikipedia.org/wiki/Racah_W-coefficient
81-
- https://en.wikipedia.org/wiki/6-j_symbol
82-
83-
Args:
84-
j1: Spin quantum number 1.
85-
j2: Spin quantum number 2.
86-
j3: Spin quantum number 3.
87-
j12: Total spin quantum number of j1 + j2.
88-
j23: Total spin quantum number of j2 + j3.
89-
j_tot: Total spin quantum number of j1 + j2 + j3.
90-
91-
Returns:
92-
The Clebsch-Gordan coefficient <((j1,j2)j12,j3)j_tot|(j1,(j2,j3)j23)j_tot>.
93-
94-
"""
95-
prefactor = minus_one_pow(j1 + j2 + j3 + j_tot) * math.sqrt((2 * j12 + 1) * (2 * j23 + 1))
96-
wigner_6j = calc_wigner_6j(j1, j2, j12, j3, j_tot, j23)
97-
return prefactor * wigner_6j
98-
99-
100-
def clebsch_gordan_9j(
101-
j1: float, j2: float, j12: float, j3: float, j4: float, j34: float, j13: float, j24: float, j_tot: float
102-
) -> float:
103-
"""Calculate the overlap between <((j1,j2)j12,(j3,j4)j34))j_tot|((j1,j3)j13,(j2,j4)j24))j_tot>.
104-
105-
We follow the convention of equation (6.4.2) from Edmonds 1985 "Angular Momentum in Quantum Mechanics".
106-
107-
See Also:
108-
- https://en.wikipedia.org/wiki/9-j_symbol
109-
110-
Args:
111-
j1: Spin quantum number 1.
112-
j2: Spin quantum number 2.
113-
j12: Total spin quantum number of j1 + j2.
114-
j3: Spin quantum number 1.
115-
j4: Spin quantum number 2.
116-
j34: Total spin quantum number of j1 + j2.
117-
j13: Total spin quantum number of j1 + j3.
118-
j24: Total spin quantum number of j2 + j4.
119-
j_tot: Total spin quantum number of j1 + j2 + j3 + j4.
120-
121-
Returns:
122-
The Clebsch-Gordan coefficient <((j1,j2)j12,(j3,j4)j34))j_tot|((j1,j3)j13,(j2,j4)j24))j_tot>.
123-
124-
"""
125-
prefactor = math.sqrt((2 * j12 + 1) * (2 * j34 + 1) * (2 * j13 + 1) * (2 * j24 + 1))
126-
return prefactor * calc_wigner_9j(j1, j2, j12, j3, j4, j34, j13, j24, j_tot)
127-
128-
129-
def calc_wigner_3j_with_symmetries(j1: float, j2: float, j3: float, m1: float, m2: float, m3: float) -> float:
130-
"""Calculate the Wigner 3j symbol using symmetries to reduce the number of symbols, that are not cached."""
131-
symmetry_factor: float = 1
132-
133-
# even permutation -> sort smallest j to be j1
134-
if j2 < j1 and j2 < j3:
135-
j1, j2, j3, m1, m2, m3 = j2, j3, j1, m2, m3, m1
136-
elif j3 < j1 and j3 < j2:
137-
j1, j2, j3, m1, m2, m3 = j3, j1, j2, m3, m1, m2
138-
139-
# odd permutation -> sort second smallest j to be j2
140-
if j3 < j2:
141-
symmetry_factor *= minus_one_pow(j1 + j2 + j3)
142-
j1, j2, j3, m1, m2, m3 = j1, j3, j2, m1, m3, m2 # noqa: PLW0127
143-
144-
# sign of m -> make m1 positive (or m2 if m1==0)
145-
if m1 <= 0 or (m1 == 0 and m2 < 0):
146-
symmetry_factor *= minus_one_pow(j1 + j2 + j3)
147-
m1, m2, m3 = -m1, -m2, -m3
148-
149-
# TODO Regge symmetries
150-
151-
return symmetry_factor * calc_wigner_3j(j1, j2, j3, m1, m2, m3)
152-
153-
154-
def calc_wigner_6j_with_symmetries(j1: float, j2: float, j3: float, j4: float, j5: float, j6: float) -> float:
155-
"""Calculate the Wigner 6j symbol using symmetries to reduce the number of symbols, that are not cached."""
156-
# interchange upper and lower for 2 columns -> make j1 < j4 and j2 < j5
157-
if j4 < j1:
158-
j1, j2, j3, j4, j5, j6 = j4, j2, j6, j1, j5, j3 # noqa: PLW0127
159-
if j5 < j2:
160-
j1, j2, j3, j4, j5, j6 = j1, j5, j6, j4, j2, j3 # noqa: PLW0127
161-
162-
# any permutation of columns -> make j1 <= j2 <= j3
163-
if j2 < j1 and j2 < j3:
164-
j1, j2, j3, j4, j5, j6 = j2, j1, j3, j5, j4, j6 # noqa: PLW0127
165-
elif j3 < j1 and j3 < j2:
166-
j1, j2, j3, j4, j5, j6 = j3, j2, j1, j6, j5, j4 # noqa: PLW0127
167-
168-
if j3 < j2:
169-
j1, j2, j3, j4, j5, j6 = j1, j3, j2, j4, j6, j5 # noqa: PLW0127
170-
171-
return calc_wigner_6j(j1, j2, j3, j4, j5, j6)
172-
173-
174-
def calc_wigner_9j_with_symmetries(
175-
j1: float, j2: float, j3: float, j4: float, j5: float, j6: float, j7: float, j8: float, j9: float
176-
) -> float:
177-
"""Calculate the Wigner 9j symbol using symmetries to reduce the number of symbols, that are not cached."""
178-
symmetry_factor: float = 1
179-
js = [j1, j2, j3, j4, j5, j6, j7, j8, j9]
180-
181-
# even permutation of rows and columns -> make smallest j to be j1
182-
min_j = min(js)
183-
if min_j not in js[:3]:
184-
if min_j in js[3:6]:
185-
js = [*js[3:6], *js[6:9], *js[0:3]]
186-
elif min_j in js[6:9]:
187-
js = [*js[6:9], *js[0:3], *js[3:6]]
188-
if js[0] != min_j:
189-
if js[1] == min_j:
190-
js = [js[1], js[2], js[0], js[4], js[5], js[3], js[7], js[8], js[6]]
191-
elif js[2] == min_j:
192-
js = [js[2], js[0], js[1], js[5], js[3], js[4], js[8], js[6], js[7]]
193-
194-
# odd permutations of rows and columns-> make j2 <= j3 and j4 <= j7
195-
if js[2] < js[1]:
196-
symmetry_factor *= minus_one_pow(sum(js))
197-
js = [js[0], js[2], js[1], js[3], js[5], js[4], js[6], js[8], js[7]]
198-
if js[6] < js[3]:
199-
symmetry_factor *= minus_one_pow(sum(js))
200-
js = [*js[0:3], *js[6:9], *js[3:6]]
201-
202-
# reflection about diagonal -> make j2 <= j4
203-
if js[3] < js[1]:
204-
js = [js[0], js[3], js[6], js[1], js[4], js[7], js[2], js[5], js[8]]
205-
206-
return symmetry_factor * calc_wigner_9j(*js)
207-
208-
209-
if USE_SYMMETRIES:
210-
calc_wigner_3j = calc_wigner_3j_with_symmetries # type: ignore [assignment]
211-
calc_wigner_6j = calc_wigner_6j_with_symmetries # type: ignore [assignment]
212-
calc_wigner_9j = calc_wigner_9j_with_symmetries # type: ignore [assignment]
2134

2145

2156
def minus_one_pow(n: float) -> int:

0 commit comments

Comments
 (0)