|
1 | 1 | from functools import partial |
2 | | -from typing import Tuple |
3 | 2 |
|
4 | 3 | import jax.numpy as jnp |
5 | 4 | import numpy as np |
|
9 | 8 | def wigner_subset_to_s2( |
10 | 9 | flmn: np.ndarray, |
11 | 10 | spins: np.ndarray, |
12 | | - DW: Tuple[np.ndarray, np.ndarray], |
| 11 | + DW: tuple[np.ndarray, np.ndarray], |
13 | 12 | L: int, |
14 | 13 | sampling: str = "mw", |
15 | 14 | ) -> np.ndarray: |
@@ -91,7 +90,7 @@ def wigner_subset_to_s2( |
91 | 90 | def wigner_subset_to_s2_jax( |
92 | 91 | flmn: jnp.ndarray, |
93 | 92 | spins: jnp.ndarray, |
94 | | - DW: Tuple[jnp.ndarray, jnp.ndarray], |
| 93 | + DW: tuple[jnp.ndarray, jnp.ndarray], |
95 | 94 | L: int, |
96 | 95 | sampling: str = "mw", |
97 | 96 | ) -> jnp.ndarray: |
@@ -173,7 +172,7 @@ def wigner_subset_to_s2_jax( |
173 | 172 | def so3_to_wigner_subset( |
174 | 173 | f: np.ndarray, |
175 | 174 | spins: np.ndarray, |
176 | | - DW: Tuple[np.ndarray, np.ndarray], |
| 175 | + DW: tuple[np.ndarray, np.ndarray], |
177 | 176 | L: int, |
178 | 177 | N: int, |
179 | 178 | sampling: str = "mw", |
@@ -214,7 +213,7 @@ def so3_to_wigner_subset( |
214 | 213 | def so3_to_wigner_subset_jax( |
215 | 214 | f: jnp.ndarray, |
216 | 215 | spins: jnp.ndarray, |
217 | | - DW: Tuple[jnp.ndarray, jnp.ndarray], |
| 216 | + DW: tuple[jnp.ndarray, jnp.ndarray], |
218 | 217 | L: int, |
219 | 218 | N: int, |
220 | 219 | sampling: str = "mw", |
@@ -257,7 +256,7 @@ def so3_to_wigner_subset_jax( |
257 | 256 | def s2_to_wigner_subset( |
258 | 257 | fs: np.ndarray, |
259 | 258 | spins: np.ndarray, |
260 | | - DW: Tuple[np.ndarray, np.ndarray], |
| 259 | + DW: tuple[np.ndarray, np.ndarray], |
261 | 260 | L: int, |
262 | 261 | sampling: str = "mw", |
263 | 262 | ) -> np.ndarray: |
@@ -343,7 +342,7 @@ def s2_to_wigner_subset( |
343 | 342 | def s2_to_wigner_subset_jax( |
344 | 343 | fs: jnp.ndarray, |
345 | 344 | spins: jnp.ndarray, |
346 | | - DW: Tuple[jnp.ndarray, jnp.ndarray], |
| 345 | + DW: tuple[jnp.ndarray, jnp.ndarray], |
347 | 346 | L: int, |
348 | 347 | sampling: str = "mw", |
349 | 348 | ) -> jnp.ndarray: |
|
0 commit comments