|
| 1 | +import jax.numpy as jnp |
| 2 | +from jax import jit |
| 3 | +from functools import partial |
| 4 | + |
| 5 | + |
| 6 | +@partial(jit, static_argnums=(1, 2, 3)) |
| 7 | +def compute_full(dl: jnp.ndarray, beta: float, L: int, el: int) -> jnp.ndarray: |
| 8 | + if el == 0: |
| 9 | + dl = dl.at[el + L - 1, el + L - 1].set(1.0) |
| 10 | + return dl |
| 11 | + elif el == 1: |
| 12 | + cosb = jnp.cos(beta) |
| 13 | + sinb = jnp.sin(beta) |
| 14 | + |
| 15 | + coshb = jnp.cos(beta / 2.0) |
| 16 | + sinhb = jnp.sin(beta / 2.0) |
| 17 | + sqrt2 = jnp.sqrt(2.0) |
| 18 | + |
| 19 | + dl = dl.at[L - 2, L - 2].set(coshb**2) |
| 20 | + dl = dl.at[L - 2, L - 1].set(sinb / sqrt2) |
| 21 | + dl = dl.at[L - 2, L].set(sinhb**2) |
| 22 | + |
| 23 | + dl = dl.at[L - 1, L - 2].set(-sinb / sqrt2) |
| 24 | + dl = dl.at[L - 1, L - 1].set(cosb) |
| 25 | + dl = dl.at[L - 1, L].set(sinb / sqrt2) |
| 26 | + |
| 27 | + dl = dl.at[L, L - 2].set(sinhb**2) |
| 28 | + dl = dl.at[L, L - 1].set(-sinb / sqrt2) |
| 29 | + dl = dl.at[L, L].set(coshb**2) |
| 30 | + return dl |
| 31 | + else: |
| 32 | + coshb = -jnp.cos(beta / 2.0) |
| 33 | + sinhb = jnp.sin(beta / 2.0) |
| 34 | + dd = jnp.zeros((2 * el + 2, 2 * el + 2)) |
| 35 | + |
| 36 | + # First pass |
| 37 | + j = 2 * el - 1 |
| 38 | + i = jnp.arange(j) |
| 39 | + k = jnp.arange(j) |
| 40 | + |
| 41 | + sqrt_jmk = jnp.sqrt(j - k) |
| 42 | + sqrt_kp1 = jnp.sqrt(k + 1) |
| 43 | + sqrt_jmi = jnp.sqrt(j - i) |
| 44 | + sqrt_ip1 = jnp.sqrt(i + 1) |
| 45 | + |
| 46 | + dlj = dl[k - (el - 1) + L - 1][:, i - (el - 1) + L - 1] |
| 47 | + |
| 48 | + dd = dd.at[:j, :j].add( |
| 49 | + jnp.einsum("i,k->ki", sqrt_jmi, sqrt_jmk, optimize=True) * dlj * coshb |
| 50 | + ) |
| 51 | + dd = dd.at[:j, 1 : j + 1].add( |
| 52 | + jnp.einsum("i,k->ki", -sqrt_ip1, sqrt_jmk, optimize=True) * dlj * sinhb |
| 53 | + ) |
| 54 | + dd = dd.at[1 : j + 1, :j].add( |
| 55 | + jnp.einsum("i,k->ki", sqrt_jmi, sqrt_kp1, optimize=True) * dlj * sinhb |
| 56 | + ) |
| 57 | + dd = dd.at[1 : j + 1, 1 : j + 1].add( |
| 58 | + jnp.einsum("i,k->ki", sqrt_ip1, sqrt_kp1, optimize=True) * dlj * coshb |
| 59 | + ) |
| 60 | + |
| 61 | + dl = dl.at[-el + L - 1 : el + 1 + L - 1, -el + L - 1 : el + 1 + L - 1].multiply( |
| 62 | + 0.0 |
| 63 | + ) |
| 64 | + |
| 65 | + j = 2 * el |
| 66 | + i = jnp.arange(j) |
| 67 | + k = jnp.arange(j) |
| 68 | + |
| 69 | + # Second pass |
| 70 | + sqrt_jmk = jnp.sqrt(j - k) |
| 71 | + sqrt_kp1 = jnp.sqrt(k + 1) |
| 72 | + sqrt_jmi = jnp.sqrt(j - i) |
| 73 | + sqrt_ip1 = jnp.sqrt(i + 1) |
| 74 | + |
| 75 | + dl = dl.at[-el + L - 1 : el + L - 1, -el + L - 1 : el + L - 1].add( |
| 76 | + jnp.einsum("i,k->ki", sqrt_jmi, sqrt_jmk, optimize=True) |
| 77 | + * dd[:j, :j] |
| 78 | + * coshb, |
| 79 | + ) |
| 80 | + dl = dl.at[-el + L - 1 : el + L - 1, L - el : L + el].add( |
| 81 | + jnp.einsum("i,k->ki", -sqrt_ip1, sqrt_jmk, optimize=True) |
| 82 | + * dd[:j, :j] |
| 83 | + * sinhb, |
| 84 | + ) |
| 85 | + dl = dl.at[L - el : L + el, -el + L - 1 : el + L - 1].add( |
| 86 | + jnp.einsum("i,k->ki", sqrt_jmi, sqrt_kp1, optimize=True) |
| 87 | + * dd[:j, :j] |
| 88 | + * sinhb, |
| 89 | + ) |
| 90 | + dl = dl.at[L - el : L + el, L - el : L + el].add( |
| 91 | + jnp.einsum("i,k->ki", sqrt_ip1, sqrt_kp1, optimize=True) |
| 92 | + * dd[:j, :j] |
| 93 | + * coshb, |
| 94 | + ) |
| 95 | + return dl / ((2 * el) * (2 * el - 1)) |
0 commit comments