2121 def lru_cache (maxsize : int ) -> Callable [[Callable [P , R ]], Callable [P , R ]]: ... # type: ignore [no-redef]
2222
2323
24+ # global variables to possibly improve the performance of wigner j calculations
25+ # in the public release we will always use CHECK_ARGS = True and USE_SYMMETRIES = False to reduce potential of bugs
26+ CHECK_ARGS = True
2427USE_SYMMETRIES = False
2528
2629
2730def sympify_args (func : Callable [P , R ]) -> Callable [P , R ]:
2831 """Check that quantum numbers are valid and convert to sympy.Integer (and half-integer)."""
32+ if not CHECK_ARGS :
33+ return func
2934
3035 def check_arg (arg : float ) -> Integer :
3136 if arg .is_integer ():
@@ -43,8 +48,86 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
4348 return wrapper
4449
4550
51+ @lru_cache (maxsize = 10_000 )
52+ @sympify_args
4653def calc_wigner_3j (j1 : float , j2 : float , j3 : float , m1 : float , m2 : float , m3 : float ) -> float :
47- """Calculate the Wigner 3j symbol using symmetries and lru_cache to improve performance."""
54+ """Calculate the Wigner 3j symbol using lru_cache to improve performance."""
55+ return float (sympy_wigner_3j (j1 , j2 , j3 , m1 , m2 , m3 ).evalf ())
56+
57+
58+ @lru_cache (maxsize = 100_000 )
59+ @sympify_args
60+ def calc_wigner_6j (j1 : float , j2 : float , j3 : float , j4 : float , j5 : float , j6 : float ) -> float :
61+ """Calculate the Wigner 6j symbol using lru_cache to improve performance."""
62+ return float (sympy_wigner_6j (j1 , j2 , j3 , j4 , j5 , j6 ).evalf ())
63+
64+
65+ @lru_cache (maxsize = 10_000 )
66+ @sympify_args
67+ def calc_wigner_9j (
68+ j1 : float , j2 : float , j3 : float , j4 : float , j5 : float , j6 : float , j7 : float , j8 : float , j9 : float
69+ ) -> float :
70+ """Calculate the Wigner 9j symbol using lru_cache to improve performance."""
71+ return float (sympy_wigner_9j (j1 , j2 , j3 , j4 , j5 , j6 , j7 , j8 , j9 ).evalf ())
72+
73+
74+ def clebsch_gordan_6j (j1 : float , j2 : float , j3 : float , j12 : float , j23 : float , j_tot : float ) -> float :
75+ """Calculate the overlap between <((j1,j2)j12,j3)j_tot|(j1,(j2,j3)j23)j_tot>.
76+
77+ We follow the convention of equation (6.1.5) from Edmonds 1985 "Angular Momentum in Quantum Mechanics".
78+
79+ See Also:
80+ - https://en.wikipedia.org/wiki/Racah_W-coefficient
81+ - https://en.wikipedia.org/wiki/6-j_symbol
82+
83+ Args:
84+ j1: Spin quantum number 1.
85+ j2: Spin quantum number 2.
86+ j3: Spin quantum number 3.
87+ j12: Total spin quantum number of j1 + j2.
88+ j23: Total spin quantum number of j2 + j3.
89+ j_tot: Total spin quantum number of j1 + j2 + j3.
90+
91+ Returns:
92+ The Clebsch-Gordan coefficient <((j1,j2)j12,j3)j_tot|(j1,(j2,j3)j23)j_tot>.
93+
94+ """
95+ prefactor = minus_one_pow (j1 + j2 + j3 + j_tot ) * math .sqrt ((2 * j12 + 1 ) * (2 * j23 + 1 ))
96+ wigner_6j = calc_wigner_6j (j1 , j2 , j12 , j3 , j_tot , j23 )
97+ return prefactor * wigner_6j
98+
99+
100+ def clebsch_gordan_9j (
101+ j1 : float , j2 : float , j12 : float , j3 : float , j4 : float , j34 : float , j13 : float , j24 : float , j_tot : float
102+ ) -> float :
103+ """Calculate the overlap between <((j1,j2)j12,(j3,j4)j34))j_tot|((j1,j3)j13,(j2,j4)j24))j_tot>.
104+
105+ We follow the convention of equation (6.4.2) from Edmonds 1985 "Angular Momentum in Quantum Mechanics".
106+
107+ See Also:
108+ - https://en.wikipedia.org/wiki/9-j_symbol
109+
110+ Args:
111+ j1: Spin quantum number 1.
112+ j2: Spin quantum number 2.
113+ j12: Total spin quantum number of j1 + j2.
114+ j3: Spin quantum number 1.
115+ j4: Spin quantum number 2.
116+ j34: Total spin quantum number of j1 + j2.
117+ j13: Total spin quantum number of j1 + j3.
118+ j24: Total spin quantum number of j2 + j4.
119+ j_tot: Total spin quantum number of j1 + j2 + j3 + j4.
120+
121+ Returns:
122+ The Clebsch-Gordan coefficient <((j1,j2)j12,(j3,j4)j34))j_tot|((j1,j3)j13,(j2,j4)j24))j_tot>.
123+
124+ """
125+ prefactor = math .sqrt ((2 * j12 + 1 ) * (2 * j34 + 1 ) * (2 * j13 + 1 ) * (2 * j24 + 1 ))
126+ return prefactor * calc_wigner_9j (j1 , j2 , j12 , j3 , j4 , j34 , j13 , j24 , j_tot )
127+
128+
129+ def calc_wigner_3j_with_symmetries (j1 : float , j2 : float , j3 : float , m1 : float , m2 : float , m3 : float ) -> float :
130+ """Calculate the Wigner 3j symbol using symmetries to reduce the number of symbols, that are not cached."""
48131 symmetry_factor : float = 1
49132
50133 # even permutation -> sort smallest j to be j1
@@ -65,17 +148,11 @@ def calc_wigner_3j(j1: float, j2: float, j3: float, m1: float, m2: float, m3: fl
65148
66149 # TODO Regge symmetries
67150
68- return symmetry_factor * _calc_wigner_3j (j1 , j2 , j3 , m1 , m2 , m3 )
151+ return symmetry_factor * calc_wigner_3j (j1 , j2 , j3 , m1 , m2 , m3 )
69152
70153
71- @lru_cache (maxsize = 10_000 )
72- @sympify_args
73- def _calc_wigner_3j (j1 : float , j2 : float , j3 : float , m1 : float , m2 : float , m3 : float ) -> float :
74- return float (sympy_wigner_3j (j1 , j2 , j3 , m1 , m2 , m3 ).evalf ())
75-
76-
77- def calc_wigner_6j (j1 : float , j2 : float , j3 : float , j4 : float , j5 : float , j6 : float ) -> float :
78- """Calculate the Wigner 6j symbol using symmetries and lru_cache to improve performance."""
154+ def calc_wigner_6j_with_symmetries (j1 : float , j2 : float , j3 : float , j4 : float , j5 : float , j6 : float ) -> float :
155+ """Calculate the Wigner 6j symbol using symmetries to reduce the number of symbols, that are not cached."""
79156 # interchange upper and lower for 2 columns -> make j1 < j4 and j2 < j5
80157 if j4 < j1 :
81158 j1 , j2 , j3 , j4 , j5 , j6 = j4 , j2 , j6 , j1 , j5 , j3 # noqa: PLW0127
@@ -91,19 +168,13 @@ def calc_wigner_6j(j1: float, j2: float, j3: float, j4: float, j5: float, j6: fl
91168 if j3 < j2 :
92169 j1 , j2 , j3 , j4 , j5 , j6 = j1 , j3 , j2 , j4 , j6 , j5 # noqa: PLW0127
93170
94- return _calc_wigner_6j (j1 , j2 , j3 , j4 , j5 , j6 )
95-
171+ return calc_wigner_6j (j1 , j2 , j3 , j4 , j5 , j6 )
96172
97- @lru_cache (maxsize = 100_000 )
98- @sympify_args
99- def _calc_wigner_6j (j1 : float , j2 : float , j3 : float , j4 : float , j5 : float , j6 : float ) -> float :
100- return float (sympy_wigner_6j (j1 , j2 , j3 , j4 , j5 , j6 ).evalf ())
101173
102-
103- def calc_wigner_9j (
174+ def calc_wigner_9j_with_symmetries (
104175 j1 : float , j2 : float , j3 : float , j4 : float , j5 : float , j6 : float , j7 : float , j8 : float , j9 : float
105176) -> float :
106- """Calculate the Wigner 9j symbol using symmetries and lru_cache to improve performance ."""
177+ """Calculate the Wigner 9j symbol using symmetries to reduce the number of symbols, that are not cached ."""
107178 symmetry_factor : float = 1
108179 js = [j1 , j2 , j3 , j4 , j5 , j6 , j7 , j8 , j9 ]
109180
@@ -132,78 +203,22 @@ def calc_wigner_9j(
132203 if js [3 ] < js [1 ]:
133204 js = [js [0 ], js [3 ], js [6 ], js [1 ], js [4 ], js [7 ], js [2 ], js [5 ], js [8 ]]
134205
135- return symmetry_factor * _calc_wigner_9j (* js )
136-
137-
138- @lru_cache (maxsize = 10_000 )
139- @sympify_args
140- def _calc_wigner_9j (
141- j1 : float , j2 : float , j3 : float , j4 : float , j5 : float , j6 : float , j7 : float , j8 : float , j9 : float
142- ) -> float :
143- return float (sympy_wigner_9j (j1 , j2 , j3 , j4 , j5 , j6 , j7 , j8 , j9 ).evalf ())
144-
145-
146- def clebsch_gordan_6j (j1 : float , j2 : float , j3 : float , j12 : float , j23 : float , j_tot : float ) -> float :
147- """Calculate the overlap between <((j1,j2)j12,j3)j_tot|(j1,(j2,j3)j23)j_tot>.
148-
149- We follow the convention of equation (6.1.5) from Edmonds 1985 "Angular Momentum in Quantum Mechanics".
150-
151- See Also:
152- - https://en.wikipedia.org/wiki/Racah_W-coefficient
153- - https://en.wikipedia.org/wiki/6-j_symbol
154-
155- Args:
156- j1: Spin quantum number 1.
157- j2: Spin quantum number 2.
158- j3: Spin quantum number 3.
159- j12: Total spin quantum number of j1 + j2.
160- j23: Total spin quantum number of j2 + j3.
161- j_tot: Total spin quantum number of j1 + j2 + j3.
162-
163- Returns:
164- The Clebsch-Gordan coefficient <((j1,j2)j12,j3)j_tot|(j1,(j2,j3)j23)j_tot>.
206+ return symmetry_factor * calc_wigner_9j (* js )
165207
166- """
167- prefactor = minus_one_pow (j1 + j2 + j3 + j_tot ) * math .sqrt ((2 * j12 + 1 ) * (2 * j23 + 1 ))
168- wigner_6j = calc_wigner_6j (j1 , j2 , j12 , j3 , j_tot , j23 )
169- return prefactor * wigner_6j
170208
171-
172- def clebsch_gordan_9j (
173- j1 : float , j2 : float , j12 : float , j3 : float , j4 : float , j34 : float , j13 : float , j24 : float , j_tot : float
174- ) -> float :
175- """Calculate the overlap between <((j1,j2)j12,(j3,j4)j34))j_tot|((j1,j3)j13,(j2,j4)j24))j_tot>.
176-
177- We follow the convention of equation (6.4.2) from Edmonds 1985 "Angular Momentum in Quantum Mechanics".
178-
179- See Also:
180- - https://en.wikipedia.org/wiki/9-j_symbol
181-
182- Args:
183- j1: Spin quantum number 1.
184- j2: Spin quantum number 2.
185- j12: Total spin quantum number of j1 + j2.
186- j3: Spin quantum number 1.
187- j4: Spin quantum number 2.
188- j34: Total spin quantum number of j1 + j2.
189- j13: Total spin quantum number of j1 + j3.
190- j24: Total spin quantum number of j2 + j4.
191- j_tot: Total spin quantum number of j1 + j2 + j3 + j4.
192-
193- Returns:
194- The Clebsch-Gordan coefficient <((j1,j2)j12,(j3,j4)j34))j_tot|((j1,j3)j13,(j2,j4)j24))j_tot>.
195-
196- """
197- prefactor = math .sqrt ((2 * j12 + 1 ) * (2 * j34 + 1 ) * (2 * j13 + 1 ) * (2 * j24 + 1 ))
198- return prefactor * calc_wigner_9j (j1 , j2 , j12 , j3 , j4 , j34 , j13 , j24 , j_tot )
209+ if USE_SYMMETRIES :
210+ calc_wigner_3j = calc_wigner_3j_with_symmetries # type: ignore [assignment]
211+ calc_wigner_6j = calc_wigner_6j_with_symmetries # type: ignore [assignment]
212+ calc_wigner_9j = calc_wigner_9j_with_symmetries # type: ignore [assignment]
199213
200214
201215def minus_one_pow (n : float ) -> int :
216+ """Calculate (-1)^n for an integer n and raise an error if n is not an integer."""
202217 if n % 2 == 0 :
203218 return 1
204219 if n % 2 == 1 :
205220 return - 1
206- raise ValueError (f"Invalid input { n } ." )
221+ raise ValueError (f"minus_one_pow: Invalid input { n = } is not an integer ." )
207222
208223
209224def try_trivial_spin_addition (s_1 : float , s_2 : float , s_tot : float | None , name : str ) -> float :
@@ -235,9 +250,3 @@ def get_possible_quantum_number_list(s_1: float, s_2: float, s_tot: float | None
235250 if s_tot is not None :
236251 return [s_tot ]
237252 return [float (s ) for s in np .arange (abs (s_1 - s_2 ), s_1 + s_2 + 1 , 1 )]
238-
239-
240- if not USE_SYMMETRIES :
241- calc_wigner_3j = _calc_wigner_3j
242- calc_wigner_6j = _calc_wigner_6j
243- calc_wigner_9j = _calc_wigner_9j
0 commit comments