-
Here is a reproducible example:
import jax.numpy as jnp
from jax import jit, grad, lax
# The `jax.config` submodule is no longer importable in current JAX releases;
# the config object itself is exposed on the top-level `jax` package.
from jax import config
# Switch JAX from its 32-bit default to 64-bit floats / 128-bit complex values.
config.update("jax_enable_x64", True)
# Pin execution to the CPU backend.
config.update('jax_platform_name', 'cpu')
@jit
def _poly_coeffs_binary(w, a, e1):
    """Build the six coefficients of a degree-5 complex polynomial.

    The coefficients are functions of ``w``, its complex conjugate, and the
    two parameters ``a`` and ``e1``.  They are ordered ``p_0 .. p_5``; the
    caller in this file hands them to ``jnp.roots``, which treats the first
    entry as the highest-degree coefficient.

    NOTE(review): presumably this is the polynomial form of the binary lens
    equation implemented by ``lens_eq_binary`` -- confirm against the
    derivation.

    Args:
        w: Complex scalar or array.
        a: Real parameter; enters only through powers of ``a``.
        e1: Real parameter.

    Returns:
        Array of shape ``w.shape + (6,)`` holding ``p_0 .. p_5`` along the
        last axis.
    """
    wbar = jnp.conjugate(w)

    p_0 = -(a ** 2) + wbar ** 2
    p_1 = a ** 2 * w - 2 * a * e1 + a - w * wbar ** 2 + wbar
    p_2 = (
        2 * a ** 4
        - 2 * a ** 2 * wbar ** 2
        + 4 * a * wbar * e1
        - 2 * a * wbar
        - 2 * w * wbar
    )
    p_3 = (
        -2 * a ** 4 * w
        + 4 * a ** 3 * e1
        - 2 * a ** 3
        + 2 * a ** 2 * w * wbar ** 2
        - 4 * a * w * wbar * e1
        + 2 * a * w * wbar
        + 2 * a * e1
        - a
        - w
    )
    p_4 = (
        -(a ** 6)
        + a ** 4 * wbar ** 2
        - 4 * a ** 3 * wbar * e1
        + 2 * a ** 3 * wbar
        + 2 * a ** 2 * w * wbar
        + 4 * a ** 2 * e1 ** 2
        - 4 * a ** 2 * e1
        + 2 * a ** 2
        - 4 * a * w * e1
        + 2 * a * w
    )
    p_5 = (
        a ** 6 * w
        - 2 * a ** 5 * e1
        + a ** 5
        - a ** 4 * w * wbar ** 2
        - a ** 4 * wbar
        + 4 * a ** 3 * w * wbar * e1
        - 2 * a ** 3 * w * wbar
        + 2 * a ** 3 * e1
        - a ** 3
        - 4 * a ** 2 * w * e1 ** 2
        + 4 * a ** 2 * w * e1
        - a ** 2 * w
    )

    # Stack along a new leading axis, then move it last so that each element
    # of `w` owns a contiguous length-6 coefficient vector.
    p = jnp.stack([p_0, p_1, p_2, p_3, p_4, p_5])
    return jnp.moveaxis(p, 0, -1)
@jit
def lens_eq_binary(z, a, e1):
    """Evaluate ``z - e1 / (conj(z) - a) - (1 - e1) / (conj(z) + a)``.

    Args:
        z: Complex scalar or array at which to evaluate the expression.
        a: Offset appearing in the two denominators (``-a`` and ``+a``).
        e1: Weight of the first fraction; the second is weighted ``1 - e1``.

    Returns:
        Complex value(s) with the same shape as ``z``.
    """
    zbar = jnp.conjugate(z)
    return z - e1 / (zbar - a) - (1.0 - e1) / (zbar + a)
@jit
def lens_eq_jac_det_binary(z, a, e1):
    """Evaluate ``1 - |e1 / (conj(z) - a)**2 + (1 - e1) / (conj(z) + a)**2|**2``.

    The caller in this file uses the result as the Jacobian determinant of
    ``lens_eq_binary`` at ``z``.

    Args:
        z: Complex scalar or array.
        a: Offset appearing in the two denominators (``-a`` and ``+a``).
        e1: Weight of the first fraction; the second is weighted ``1 - e1``.

    Returns:
        Real value(s) with the same shape as ``z``.
    """
    zbar = jnp.conjugate(z)
    return 1.0 - jnp.abs(e1 / (zbar - a) ** 2 + (1.0 - e1) / (zbar + a) ** 2) ** 2
@jit
def images_point_source_binary(w, a, e1):
    """Find candidate solutions ``z`` of ``lens_eq_binary(z, a, e1) == w``.

    The roots of the degree-5 polynomial from ``_poly_coeffs_binary`` are
    computed, and each one is then checked against the lens equation.

    NOTE(review): ``jnp.roots`` documents a rank-1 coefficient array, so this
    presumably only works for scalar ``w`` -- confirm before passing arrays.

    Args:
        w: Complex right-hand side of the lens equation.
        a: Parameter forwarded to the coefficient and lens-equation helpers.
        e1: Parameter forwarded to the coefficient and lens-equation helpers.

    Returns:
        Tuple ``(roots, mask_solutions)``: all polynomial roots, and a boolean
        array of the same shape that is True where the root satisfies the
        lens equation to within ``1e-5`` in absolute value.
    """
    # Complex polynomial coefficients for w.
    coeffs = _poly_coeffs_binary(w, a, e1)

    # strip_zeros=False keeps the number of returned roots fixed, which
    # jnp.roots needs in order to be traced under jit.
    roots = jnp.roots(coeffs, strip_zeros=False)
    roots = jnp.moveaxis(roots, -1, 0)

    # Residual of the lens equation at each root.
    lens_eq_eval = lens_eq_binary(roots, a, e1) - w

    # Flag only the roots that actually satisfy the lens equation.
    mask_solutions = jnp.abs(lens_eq_eval) < 1e-5
    return roots, mask_solutions
@jit
def mag_point_source_binary(w, a, e1):
    """Sum ``1 / |det|`` over the roots that satisfy the lens equation.

    Roots and their validity mask come from ``images_point_source_binary``;
    ``det`` is ``lens_eq_jac_det_binary`` evaluated at each root.

    Args:
        w: Complex right-hand side of the lens equation.
        a: Parameter forwarded to the helper functions.
        e1: Parameter forwarded to the helper functions.

    Returns:
        The summed value, reshaped to ``w.shape``.
    """
    images, mask = images_point_source_binary(w, a, e1)
    det = lens_eq_jac_det_binary(images, a, e1)
    # Select rather than multiply by the mask: a rejected root with det == 0
    # would give inf * 0 = NaN and poison the whole sum.
    mag = jnp.where(mask, 1.0 / jnp.abs(det), 0.0)
    return mag.sum(axis=0).reshape(w.shape)
# Test input: a purely imaginary complex number.
w = complex(0.0, 0.3)


def f(w):
    """Call mag_point_source_binary at `w` with a = 0.5 * 0.9 and e1 = 0.8."""
    return mag_point_source_binary(w, 0.5 * 0.9, 0.8)
%%timeit
f(w).block_until_ready()
%%timeit
grad(f)(w).block_until_ready()
Calling |
Beta Was this translation helpful? Give feedback.
Answered by
mattjj
Apr 14, 2022
Replies: 1 comment 4 replies
-
Try putting `jit` on the outside of `grad` so that we can push more of the computation to XLA:
In [4]: %timeit f(w).block_until_ready()
24.1 µs ± 13.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [5]: timeit grad(f)(w).block_until_ready()
3.27 ms ± 198 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [6]: grad_f = jit(grad(f))
In [7]: timeit grad_f(w).block_until_ready()
32.2 µs ± 15.6 µs per loop (mean ± std. dev. of 7 runs, 1 loop each) |
Beta Was this translation helpful? Give feedback.
4 replies
Answer selected by
fbartolic
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Try putting `jit` on the outside of `grad` so that we can push more of the computation to XLA: