Best way to speed up optimization using JAX #19331
Hi all, I'm trying to figure out the best way to optimize a piece of my code. For context, I am given a joint probability distribution (encoded as a characteristic function) and need to compute the maximum of its absolute value along a particular direction. My code to do this, using JAX, is given below.

```python
import time

import jax.numpy as jnp
from jax import jit, grad, value_and_grad
from jaxopt import LBFGS, BFGS

# Define the original function: characteristic function of a multivariate
# Gaussian with mean mu and covariance Sigma
@jit
def joint_cf(t, mu, Sigma):
    return jnp.exp(1j * jnp.dot(mu, t) - 0.5 * jnp.dot(t, jnp.dot(Sigma, t)))

# Define a function that replaces the j-th element of t with s and computes the modulus
@jit
def modulus_at_s(s, t, mu, Sigma, j):
    # Replace the j-th component with s and keep the rest of t the same
    t = t.at[j].set(s[0])
    return -jnp.abs(joint_cf(t, mu, Sigma))  # negative for minimization

modulus_value_and_grad = jit(value_and_grad(modulus_at_s, argnums=0))

# Example usage
mu = jnp.array([1.0, 2.0])
Sigma = jnp.array([[1.0, 0.2], [0.2, 1.0]])
t = jnp.array([0.5, -0.3])

# Choose the index j for the direction of optimization
j = 1  # For example, the second component

# Initial guess for s
initial_s = jnp.array([t[j]])

# Initialize optimizers
# optimizer_LBFGS = LBFGS(modulus_value_and_grad, value_and_grad=True, maxiter=100)  # Limited (memory) BFGS
optimizer_LBFGS = LBFGS(modulus_at_s, maxiter=100)  # Limited (memory) BFGS

# # # WARM-START # # #
# First call: includes tracing and compilation time
tic = time.perf_counter()
result_LBFGS = optimizer_LBFGS.run(initial_s, t, mu, Sigma, j)
# JAX dispatches work asynchronously; wait for the result before stopping the clock
result_LBFGS.params.block_until_ready()
toc = time.perf_counter()
elapsed_time_ms_WS = (toc - tic) * 1000
print(f"Warm-start optimization time (L-BFGS): {elapsed_time_ms_WS:.3f} ms")

# # # MAIN # # #
tic = time.perf_counter()
result_LBFGS = optimizer_LBFGS.run(initial_s, t, mu, Sigma, j)
result_LBFGS.params.block_until_ready()
toc = time.perf_counter()
elapsed_time_ms_main = (toc - tic) * 1000
print(f"Main optimization time (L-BFGS): {elapsed_time_ms_main:.3f} ms")
```

Currently, the performance of this code is around 1500 ms for the warm-start and 500 ms for the main run, which is quite slow for my application. I've applied JIT compilation to all the relevant functions and am using the preferred …
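For the Gaussian example in the code (and only for it; a general characteristic function will not have this structure), the maximum can be found without an iterative solver. Since exp(i mu^T t) has modulus 1, |phi(t)| = exp(-0.5 t^T Sigma t), so maximizing over s = t_j with the other components fixed means minimizing the quadratic Sigma_jj s^2 + 2 s sum_{k != j} Sigma_jk t_k, whose minimizer is s* = -(sum_{k != j} Sigma_jk t_k) / Sigma_jj (assuming Sigma_jj > 0). For the example values this gives -0.2 * 0.5 / 1.0 = -0.1. The sketch below computes s* and checks that the gradient of the question's objective vanishes there; the helper names `argmax_modulus_gaussian` and `neg_modulus_at_s` are hypothetical.

```python
import jax
import jax.numpy as jnp


def argmax_modulus_gaussian(t, Sigma, j):
    # Hypothetical helper, valid only for a Gaussian CF:
    # |phi(t)| = exp(-0.5 * t^T Sigma t), because the mu term is a unit-modulus phase.
    # Holding t[k] fixed for k != j, t^T Sigma t is a quadratic in s = t[j],
    # minimized where Sigma[j, j] * s + sum_{k != j} Sigma[j, k] * t[k] = 0.
    cross = jnp.dot(Sigma[j], t) - Sigma[j, j] * t[j]  # sum over k != j
    return -cross / Sigma[j, j]


def neg_modulus_at_s(s, t, mu, Sigma, j):
    # Same objective as modulus_at_s in the question, with s as a scalar
    t = t.at[j].set(s)
    return -jnp.abs(jnp.exp(1j * jnp.dot(mu, t) - 0.5 * jnp.dot(t, jnp.dot(Sigma, t))))


mu = jnp.array([1.0, 2.0])
Sigma = jnp.array([[1.0, 0.2], [0.2, 1.0]])
t = jnp.array([0.5, -0.3])
j = 1

s_star = argmax_modulus_gaussian(t, Sigma, j)
# The objective's slope at s_star should be (numerically) zero
slope = jax.grad(neg_modulus_at_s)(s_star, t, mu, Sigma, j)
print(float(s_star), float(slope))
```

Besides being essentially free to evaluate, a value like this is a convenient reference for checking whatever iterative solver ends up being used in the general case.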
Replies: 1 comment
From the JAX side of things, everything looks fine to me. It sounds like your question has more to do with the performance of jaxopt optimization routines; for that you might find more help by asking at the jaxopt repository.
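Two JAX-level checks may still be worth trying, assuming that at least part of the 500 ms is per-call overhead rather than solver work. First, JAX dispatches computations asynchronously, so timings should wait on the result with `block_until_ready()`. Second, wrapping the entire solve in `jax.jit` lets repeated calls with the same shapes and dtypes reuse one compiled loop. To keep the sketch self-contained, it swaps jaxopt's `LBFGS` for a different solver, the BFGS in JAX's own `jax.scipy.optimize.minimize`; the `solve` helper is hypothetical and no particular timings are implied.

```python
import time

import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize


def joint_cf(t, mu, Sigma):
    return jnp.exp(1j * jnp.dot(mu, t) - 0.5 * jnp.dot(t, jnp.dot(Sigma, t)))


def modulus_at_s(s, t, mu, Sigma, j):
    t = t.at[j].set(s[0])
    return -jnp.abs(joint_cf(t, mu, Sigma))  # negative for minimization


@jax.jit
def solve(initial_s, t, mu, Sigma, j):
    # Hypothetical helper: the whole BFGS loop is traced and compiled once;
    # later calls with the same shapes/dtypes reuse the compiled executable.
    res = minimize(modulus_at_s, initial_s, args=(t, mu, Sigma, j), method="BFGS")
    return res.x


mu = jnp.array([1.0, 2.0])
Sigma = jnp.array([[1.0, 0.2], [0.2, 1.0]])
t = jnp.array([0.5, -0.3])
j = 1
initial_s = jnp.array([t[j]])

for label in ("first call (includes compilation)", "second call"):
    tic = time.perf_counter()
    # Block so the timer measures the computation, not just its dispatch
    s_opt = solve(initial_s, t, mu, Sigma, j).block_until_ready()
    toc = time.perf_counter()
    print(f"{label}: {(toc - tic) * 1000:.3f} ms, s* = {float(s_opt[0]):.6f}")
```

For the example values, `s*` should come out close to -0.1. Whether wrapping `optimizer_LBFGS.run` itself in `jax.jit` has a similar effect is worth checking against the jaxopt documentation.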