|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -import math |
4 | | -from functools import lru_cache, wraps |
5 | | -from typing import TYPE_CHECKING, Callable, TypeVar |
6 | | - |
7 | 3 | import numpy as np |
8 | | -from sympy import Integer |
9 | | -from sympy.physics.wigner import ( |
10 | | - wigner_3j as sympy_wigner_3j, |
11 | | - wigner_6j as sympy_wigner_6j, |
12 | | - wigner_9j as sympy_wigner_9j, |
13 | | -) |
14 | | - |
15 | | -if TYPE_CHECKING: |
16 | | - from typing_extensions import ParamSpec |
17 | | - |
18 | | - P = ParamSpec("P") |
19 | | - R = TypeVar("R") |
20 | | - |
21 | | - def lru_cache(maxsize: int) -> Callable[[Callable[P, R]], Callable[P, R]]: ... # type: ignore [no-redef] |
22 | | - |
23 | | - |
24 | | -# global variables to possibly improve the performance of wigner j calculations |
25 | | -# in the public release we will always use CHECK_ARGS = True and USE_SYMMETRIES = False to reduce potential of bugs |
26 | | -CHECK_ARGS = True |
27 | | -USE_SYMMETRIES = False |
28 | | - |
29 | | - |
30 | | -def sympify_args(func: Callable[P, R]) -> Callable[P, R]: |
31 | | - """Check that quantum numbers are valid and convert to sympy.Integer (and half-integer).""" |
32 | | - if not CHECK_ARGS: |
33 | | - return func |
34 | | - |
35 | | - def check_arg(arg: float) -> Integer: |
36 | | - if isinstance(arg, int) or arg.is_integer(): |
37 | | - return Integer(int(arg)) |
38 | | - if isinstance(arg * 2, int) or (arg * 2).is_integer(): |
39 | | - return Integer(int(arg * 2)) / Integer(2) |
40 | | - raise ValueError(f"Invalid input to {func.__name__}: {arg}.") |
41 | | - |
42 | | - @wraps(func) |
43 | | - def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: |
44 | | - _args = [check_arg(arg) for arg in args] # type: ignore[arg-type] |
45 | | - _kwargs = {key: check_arg(value) for key, value in kwargs.items()} # type: ignore[arg-type] |
46 | | - return func(*_args, **_kwargs) |
47 | | - |
48 | | - return wrapper |
49 | | - |
50 | | - |
51 | | -@lru_cache(maxsize=100_000) |
52 | | -@sympify_args |
53 | | -def calc_wigner_3j(j1: float, j2: float, j3: float, m1: float, m2: float, m3: float) -> float: |
54 | | - """Calculate the Wigner 3j symbol using lru_cache to improve performance.""" |
55 | | - return float(sympy_wigner_3j(j1, j2, j3, m1, m2, m3).evalf()) |
56 | | - |
57 | | - |
58 | | -@lru_cache(maxsize=100_000) |
59 | | -@sympify_args |
60 | | -def calc_wigner_6j(j1: float, j2: float, j3: float, j4: float, j5: float, j6: float) -> float: |
61 | | - """Calculate the Wigner 6j symbol using lru_cache to improve performance.""" |
62 | | - return float(sympy_wigner_6j(j1, j2, j3, j4, j5, j6).evalf()) |
63 | | - |
64 | | - |
65 | | -@lru_cache(maxsize=10_000) |
66 | | -@sympify_args |
67 | | -def calc_wigner_9j( |
68 | | - j1: float, j2: float, j3: float, j4: float, j5: float, j6: float, j7: float, j8: float, j9: float |
69 | | -) -> float: |
70 | | - """Calculate the Wigner 9j symbol using lru_cache to improve performance.""" |
71 | | - return float(sympy_wigner_9j(j1, j2, j3, j4, j5, j6, j7, j8, j9).evalf()) |
72 | | - |
73 | | - |
74 | | -def clebsch_gordan_6j(j1: float, j2: float, j3: float, j12: float, j23: float, j_tot: float) -> float: |
75 | | - """Calculate the overlap between <((j1,j2)j12,j3)j_tot|(j1,(j2,j3)j23)j_tot>. |
76 | | -
|
77 | | - We follow the convention of equation (6.1.5) from Edmonds 1985 "Angular Momentum in Quantum Mechanics". |
78 | | -
|
79 | | - See Also: |
80 | | - - https://en.wikipedia.org/wiki/Racah_W-coefficient |
81 | | - - https://en.wikipedia.org/wiki/6-j_symbol |
82 | | -
|
83 | | - Args: |
84 | | - j1: Spin quantum number 1. |
85 | | - j2: Spin quantum number 2. |
86 | | - j3: Spin quantum number 3. |
87 | | - j12: Total spin quantum number of j1 + j2. |
88 | | - j23: Total spin quantum number of j2 + j3. |
89 | | - j_tot: Total spin quantum number of j1 + j2 + j3. |
90 | | -
|
91 | | - Returns: |
92 | | - The Clebsch-Gordan coefficient <((j1,j2)j12,j3)j_tot|(j1,(j2,j3)j23)j_tot>. |
93 | | -
|
94 | | - """ |
95 | | - prefactor = minus_one_pow(j1 + j2 + j3 + j_tot) * math.sqrt((2 * j12 + 1) * (2 * j23 + 1)) |
96 | | - wigner_6j = calc_wigner_6j(j1, j2, j12, j3, j_tot, j23) |
97 | | - return prefactor * wigner_6j |
98 | | - |
99 | | - |
100 | | -def clebsch_gordan_9j( |
101 | | - j1: float, j2: float, j12: float, j3: float, j4: float, j34: float, j13: float, j24: float, j_tot: float |
102 | | -) -> float: |
103 | | - """Calculate the overlap between <((j1,j2)j12,(j3,j4)j34))j_tot|((j1,j3)j13,(j2,j4)j24))j_tot>. |
104 | | -
|
105 | | - We follow the convention of equation (6.4.2) from Edmonds 1985 "Angular Momentum in Quantum Mechanics". |
106 | | -
|
107 | | - See Also: |
108 | | - - https://en.wikipedia.org/wiki/9-j_symbol |
109 | | -
|
110 | | - Args: |
111 | | - j1: Spin quantum number 1. |
112 | | - j2: Spin quantum number 2. |
113 | | - j12: Total spin quantum number of j1 + j2. |
114 | | - j3: Spin quantum number 1. |
115 | | - j4: Spin quantum number 2. |
116 | | - j34: Total spin quantum number of j1 + j2. |
117 | | - j13: Total spin quantum number of j1 + j3. |
118 | | - j24: Total spin quantum number of j2 + j4. |
119 | | - j_tot: Total spin quantum number of j1 + j2 + j3 + j4. |
120 | | -
|
121 | | - Returns: |
122 | | - The Clebsch-Gordan coefficient <((j1,j2)j12,(j3,j4)j34))j_tot|((j1,j3)j13,(j2,j4)j24))j_tot>. |
123 | | -
|
124 | | - """ |
125 | | - prefactor = math.sqrt((2 * j12 + 1) * (2 * j34 + 1) * (2 * j13 + 1) * (2 * j24 + 1)) |
126 | | - return prefactor * calc_wigner_9j(j1, j2, j12, j3, j4, j34, j13, j24, j_tot) |
127 | | - |
128 | | - |
129 | | -def calc_wigner_3j_with_symmetries(j1: float, j2: float, j3: float, m1: float, m2: float, m3: float) -> float: |
130 | | - """Calculate the Wigner 3j symbol using symmetries to reduce the number of symbols, that are not cached.""" |
131 | | - symmetry_factor: float = 1 |
132 | | - |
133 | | - # even permutation -> sort smallest j to be j1 |
134 | | - if j2 < j1 and j2 < j3: |
135 | | - j1, j2, j3, m1, m2, m3 = j2, j3, j1, m2, m3, m1 |
136 | | - elif j3 < j1 and j3 < j2: |
137 | | - j1, j2, j3, m1, m2, m3 = j3, j1, j2, m3, m1, m2 |
138 | | - |
139 | | - # odd permutation -> sort second smallest j to be j2 |
140 | | - if j3 < j2: |
141 | | - symmetry_factor *= minus_one_pow(j1 + j2 + j3) |
142 | | - j1, j2, j3, m1, m2, m3 = j1, j3, j2, m1, m3, m2 # noqa: PLW0127 |
143 | | - |
144 | | - # sign of m -> make m1 positive (or m2 if m1==0) |
145 | | - if m1 <= 0 or (m1 == 0 and m2 < 0): |
146 | | - symmetry_factor *= minus_one_pow(j1 + j2 + j3) |
147 | | - m1, m2, m3 = -m1, -m2, -m3 |
148 | | - |
149 | | - # TODO Regge symmetries |
150 | | - |
151 | | - return symmetry_factor * calc_wigner_3j(j1, j2, j3, m1, m2, m3) |
152 | | - |
153 | | - |
154 | | -def calc_wigner_6j_with_symmetries(j1: float, j2: float, j3: float, j4: float, j5: float, j6: float) -> float: |
155 | | - """Calculate the Wigner 6j symbol using symmetries to reduce the number of symbols, that are not cached.""" |
156 | | - # interchange upper and lower for 2 columns -> make j1 < j4 and j2 < j5 |
157 | | - if j4 < j1: |
158 | | - j1, j2, j3, j4, j5, j6 = j4, j2, j6, j1, j5, j3 # noqa: PLW0127 |
159 | | - if j5 < j2: |
160 | | - j1, j2, j3, j4, j5, j6 = j1, j5, j6, j4, j2, j3 # noqa: PLW0127 |
161 | | - |
162 | | - # any permutation of columns -> make j1 <= j2 <= j3 |
163 | | - if j2 < j1 and j2 < j3: |
164 | | - j1, j2, j3, j4, j5, j6 = j2, j1, j3, j5, j4, j6 # noqa: PLW0127 |
165 | | - elif j3 < j1 and j3 < j2: |
166 | | - j1, j2, j3, j4, j5, j6 = j3, j2, j1, j6, j5, j4 # noqa: PLW0127 |
167 | | - |
168 | | - if j3 < j2: |
169 | | - j1, j2, j3, j4, j5, j6 = j1, j3, j2, j4, j6, j5 # noqa: PLW0127 |
170 | | - |
171 | | - return calc_wigner_6j(j1, j2, j3, j4, j5, j6) |
172 | | - |
173 | | - |
174 | | -def calc_wigner_9j_with_symmetries( |
175 | | - j1: float, j2: float, j3: float, j4: float, j5: float, j6: float, j7: float, j8: float, j9: float |
176 | | -) -> float: |
177 | | - """Calculate the Wigner 9j symbol using symmetries to reduce the number of symbols, that are not cached.""" |
178 | | - symmetry_factor: float = 1 |
179 | | - js = [j1, j2, j3, j4, j5, j6, j7, j8, j9] |
180 | | - |
181 | | - # even permutation of rows and columns -> make smallest j to be j1 |
182 | | - min_j = min(js) |
183 | | - if min_j not in js[:3]: |
184 | | - if min_j in js[3:6]: |
185 | | - js = [*js[3:6], *js[6:9], *js[0:3]] |
186 | | - elif min_j in js[6:9]: |
187 | | - js = [*js[6:9], *js[0:3], *js[3:6]] |
188 | | - if js[0] != min_j: |
189 | | - if js[1] == min_j: |
190 | | - js = [js[1], js[2], js[0], js[4], js[5], js[3], js[7], js[8], js[6]] |
191 | | - elif js[2] == min_j: |
192 | | - js = [js[2], js[0], js[1], js[5], js[3], js[4], js[8], js[6], js[7]] |
193 | | - |
194 | | - # odd permutations of rows and columns-> make j2 <= j3 and j4 <= j7 |
195 | | - if js[2] < js[1]: |
196 | | - symmetry_factor *= minus_one_pow(sum(js)) |
197 | | - js = [js[0], js[2], js[1], js[3], js[5], js[4], js[6], js[8], js[7]] |
198 | | - if js[6] < js[3]: |
199 | | - symmetry_factor *= minus_one_pow(sum(js)) |
200 | | - js = [*js[0:3], *js[6:9], *js[3:6]] |
201 | | - |
202 | | - # reflection about diagonal -> make j2 <= j4 |
203 | | - if js[3] < js[1]: |
204 | | - js = [js[0], js[3], js[6], js[1], js[4], js[7], js[2], js[5], js[8]] |
205 | | - |
206 | | - return symmetry_factor * calc_wigner_9j(*js) |
207 | | - |
208 | | - |
209 | | -if USE_SYMMETRIES: |
210 | | - calc_wigner_3j = calc_wigner_3j_with_symmetries # type: ignore [assignment] |
211 | | - calc_wigner_6j = calc_wigner_6j_with_symmetries # type: ignore [assignment] |
212 | | - calc_wigner_9j = calc_wigner_9j_with_symmetries # type: ignore [assignment] |
213 | 4 |
|
214 | 5 |
|
215 | 6 | def minus_one_pow(n: float) -> int: |
|
0 commit comments