-
-
Notifications
You must be signed in to change notification settings - Fork 182
How to diagnose/profile performance issues #1177
Description
I am working on a probabilistic programming language using JAX + Equinox but am running into an issue where I can't seem to get beyond 50-60% per-core CPU utilization when the size of the problem is big enough that JAX's cost modelling switches to enable multi-threading, and would like to diagnose where the issue is coming from.
I've gone through https://docs.jax.dev/en/latest/profiling.html and tried to understand the traces of various jitted functions, and I did notice there seem to be a lot of copies, but I cannot pinpoint exactly which lines of code are producing the copies (each op seems to just have an id but doesn't tell me where exactly it's coming from). I've been trying to profile this for a couple of days and can't figure it out, so I was hoping someone more experienced with JAX + Equinox could help me.
I'll highlight two snippets of code, the first is the higher-level method for a Variational object that optimizes the variational distribution:
@eqx.filter_jit
def fit(
    self,
    max_iters: int,
    learning_rate: float,
    tolerance: float, # note this doesn't do anything yet
    grad_draws: int,
    max_batch_size: int,
    key: PRNGKeyArray = jr.key(0),
    verbose: bool = True,
    print_rate: int = 5000
) -> Self:
    """
    Optimize the variational distribution.

    Runs up to ``max_iters`` steps of gradient-based optimization inside a
    ``lax.while_loop``, using Adamax (optax) under a cosine-decayed learning
    rate, and returns the recombined, optimized variational object.

    Parameters
    ----------
    max_iters : int
        Number of optimization steps; the loop condition is simply
        ``i < max_iters`` (``tolerance`` is not consulted yet).
    learning_rate : float
        Initial learning rate for the cosine decay schedule.
    tolerance : float
        Currently unused (see inline note above).
    grad_draws : int
        Total Monte-Carlo draws per gradient estimate.
    max_batch_size : int
        Upper bound on per-batch draws; the effective batch size is
        ``min(grad_draws, max_batch_size)``.
    key : PRNGKeyArray
        PRNG key threaded through the loop.
        NOTE(review): the default ``jr.key(0)`` is evaluated once at
        definition time, so every call omitting ``key`` shares the exact
        same stream — confirm this is intended.
    verbose : bool
        Whether to drive the progress-bar callbacks.
    print_rate : int
        Update frequency passed to ``update_progress``.

    Returns
    -------
    Self
        The optimized variational object (``eqx.combine`` of the trained
        dynamic leaves with the static partition).
    """
    # Create unique identifier for optimization loop
    # (sum over the raw key data — presumably unique enough per call to
    # distinguish concurrent progress bars; TODO confirm collisions are acceptable)
    loop_id = jr.key_data(key).sum()
    # Determine actual batch size for ELBO & gradient computations
    grad_batch_size = grad_draws if max_batch_size >= grad_draws else max_batch_size
    # Partition variational into trainable (dyn) and static leaves
    dyn, static = eqx.partition(self, self.filter_spec)
    # Construct schedulers: cosine decay from learning_rate down to float eps
    lr_schedule: Callable = opx.cosine_decay_schedule(
        learning_rate,
        max_iters,
        jnp.finfo(jnp.array(0.0)).eps.item()
    )
    # Initialize optimizer. scale(-1.0) negates the gradient so that
    # ascending the ELBO becomes the descent direction adamax applies.
    optim: GradientTransformation = opx.chain(
        opx.scale(-1.0),
        opx.adamax(lr_schedule)
    )
    opt_state: OptState = optim.init(dyn)
    # Initialize progress bar
    if verbose:
        update_progress(loop_id, 0, max_iters, "Fitting Variational Approximation", print_rate)
    # Helper functions for optimization loop
    def condition(state: Tuple[Self, OptState, Scalar, PRNGKeyArray]) -> Bool[Array, ""]:
        # Unpack iteration state; only the counter is used here
        dyn, opt_state, i, key = state
        return i < max_iters
    def body(state: Tuple[Self, OptState, Scalar, PRNGKeyArray]) -> Tuple[Self, OptState, Scalar, PRNGKeyArray]:
        # Unpack iteration state
        dyn, opt_state, i, key = state
        # Update iteration
        i = i + 1
        # Update progress bar
        if verbose:
            update_progress(loop_id, i, max_iters, "Fitting Variational Approximation", print_rate)
        # Update PRNG key
        # NOTE(review): the second half of the split is discarded, and the
        # carried `key` is also the one consumed by elbo_grad below — the
        # same key is both used and carried forward; verify this is intended
        # (the conventional pattern is `key, subkey = jr.split(key)` and
        # consuming `subkey`).
        key, _ = jr.split(key)
        # Reconstruct variational
        vari: Self = eqx.combine(dyn, static)
        # Compute ELBO gradient
        update: M = vari.elbo_grad(grad_draws, grad_batch_size, key)
        # Transform update through optimizer
        update, opt_state = optim.update( # type: ignore
            update, opt_state, dyn # type: ignore
        )
        # Update variational distribution
        dyn: Self = eqx.apply_updates(dyn, update)
        return dyn, opt_state, i, key
    # Run optimization loop
    # NOTE(review): `iter` shadows the Python builtin; consider renaming.
    dyn, _, iter, _ = lax.while_loop(
        cond_fun=condition,
        body_fun=body,
        init_val=(dyn, opt_state, jnp.array(0, jnp.uint32), key),
    )
    # Close progress bar
    close_progress(loop_id, iter)
    # Return optimized variational
    return eqx.combine(dyn, static)

The second is the elbo_grad method for a specific variational method NormalizingFlow:
@eqx.filter_jit
def elbo_grad(self, n: int, batch_size: int, key: PRNGKeyArray) -> Self:
    """
    Estimate the gradient of the ELBO with respect to the trainable leaves.

    Draws ``n`` samples in sequential batches of ``batch_size`` (via
    ``lax.map``), averages the per-draw ELBO values, and differentiates
    the resulting scalar with ``eqx.filter_grad``.

    Parameters
    ----------
    n : int
        Total number of Monte-Carlo draws.
        NOTE(review): ``n // batch_size`` silently drops the remainder
        when ``batch_size`` does not divide ``n`` — confirm intended.
    batch_size : int
        Draws per batch (one PRNG key per batch).
    key : PRNGKeyArray
        PRNG key; split into one key per batch.

    Returns
    -------
    Self
        A pytree of gradients matching the dynamic (trainable) partition.
    """
    dyn, static = eqx.partition(self, self.filter_spec)
    # Define ELBO function. It closes over `static` so filter_grad only
    # differentiates with respect to the dynamic leaves passed in as `dyn`.
    def elbo(dyn: Self, n: int, key: PRNGKeyArray) -> Scalar:
        # Rebinds `self` inside this closure to the recombined module
        # (shadows the method's `self` argument deliberately).
        self = eqx.combine(dyn, static)
        # Split key: one key per batch
        keys = jr.split(key, n // batch_size)
        # Split ELBO calculation into batches
        def batched_elbo(batch_key: PRNGKeyArray) -> Array:
            # Draw from variational distribution
            draws: Array = self.base.sample(batch_size, key = batch_key)
            # Evaluate posterior and variational densities
            batched_post_evals, batched_vari_evals = self._eval(draws)
            # Compute batched ELBO evals (per-draw log-density differences)
            batched_elbo_evals: Array = batched_post_evals - batched_vari_evals
            return batched_elbo_evals
        # Compute ELBO evals (lax.map evaluates the batches sequentially,
        # unlike vmap, bounding peak memory at one batch)
        elbo_evals = lax.map(batched_elbo, keys)
        # Average ELBO estimates over all draws
        elbo_est = jnp.mean(elbo_evals)
        return elbo_est
    # Map to its gradient
    elbo_grad: Callable[
        [Self, int, PRNGKeyArray], Self
    ] = eqx.filter_grad(elbo)
    return elbo_grad(dyn, n, key)

Note that the `batch_size` is always set to `n` in my tests.