Skip to content

How to diagnose/profile performance issues #1177

@toddpocuca

Description

@toddpocuca

I am working on a probabilistic programming language using JAX + Equinox. I can't get beyond 50–60% per-core CPU utilization once the problem is large enough that JAX's cost model switches on multi-threading, and I would like to diagnose where this bottleneck comes from.

I've gone through https://docs.jax.dev/en/latest/profiling.html and tried to understand the traces of various jitted functions. I did notice what seem to be a lot of copies, but I cannot pinpoint which lines of code produce them: each op only seems to have an ID and doesn't tell me where exactly it comes from. I've been trying to profile this for a couple of days without success, so I was hoping someone more experienced with JAX + Equinox could help me.

I'll highlight two snippets of code. The first is the higher-level `fit` method of a `Variational` object, which optimizes the variational distribution:

    @eqx.filter_jit
    def fit(
        self,
        max_iters: int,
        learning_rate: float,
        tolerance: float, # note this doesn't do anything yet
        grad_draws: int,
        max_batch_size: int,
        key: PRNGKeyArray = jr.key(0),
        verbose: bool = True,
        print_rate: int = 5000
    ) -> Self:
        """
        Optimize the variational distribution by stochastic gradient ascent on the ELBO.

        Runs exactly ``max_iters`` optimizer steps inside a ``lax.while_loop``.

        Args:
            max_iters: Number of optimizer steps; also the decay horizon of the
                cosine learning-rate schedule.
            learning_rate: Initial value of the cosine learning-rate schedule.
            tolerance: Currently unused (reserved for a convergence check).
            grad_draws: Number of draws per ELBO-gradient estimate.
            max_batch_size: Upper bound on the draw batch size; the effective
                batch size is ``min(grad_draws, max_batch_size)``.
            key: PRNG key driving all sampling performed by the loop.
            verbose: Whether to report progress through ``update_progress``.
            print_rate: Forwarded to ``update_progress``.

        Returns:
            ``self`` recombined with the optimized leaves (those selected by
            ``self.filter_spec``).
        """
        # Identifier for this loop's progress reporting (derived from, but does
        # not consume, the PRNG key).
        loop_id = jr.key_data(key).sum()

        # Effective batch size for the ELBO-gradient computation
        grad_batch_size = min(grad_draws, max_batch_size)

        # Split into optimized leaves (``dyn``) and everything else (``static``)
        dyn, static = eqx.partition(self, self.filter_spec)

        # Cosine decay from ``learning_rate`` down to ``eps * learning_rate``
        lr_schedule: Callable = opx.cosine_decay_schedule(
            learning_rate,
            max_iters,
            jnp.finfo(jnp.array(0.0)).eps.item()
        )

        # ``elbo_grad`` returns the ELBO gradient; negating it first turns
        # Adamax's descent step into ascent on the ELBO.
        optim: GradientTransformation = opx.chain(
            opx.scale(-1.0),
            opx.adamax(lr_schedule)
        )
        opt_state: OptState = optim.init(dyn)

        if verbose:
            update_progress(loop_id, 0, max_iters, "Fitting Variational Approximation", print_rate)

        def condition(state: Tuple[Self, OptState, Scalar, PRNGKeyArray]) -> Bool[Array, ""]:
            # Continue until ``max_iters`` steps have been taken
            _, _, step, _ = state
            return step < max_iters

        def body(state: Tuple[Self, OptState, Scalar, PRNGKeyArray]) -> Tuple[Self, OptState, Scalar, PRNGKeyArray]:
            dyn, opt_state, step, key = state
            step = step + 1

            if verbose:
                update_progress(loop_id, step, max_iters, "Fitting Variational Approximation", print_rate)

            # The carried key is only ever split, never consumed directly; the
            # fresh subkey is what this step's sampling uses. Passing ``key``
            # itself and splitting it again next iteration would reuse it.
            key, subkey = jr.split(key)

            # Rebuild the full variational object and estimate the ELBO gradient
            vari: Self = eqx.combine(dyn, static)
            update: M = vari.elbo_grad(grad_draws, grad_batch_size, subkey)

            # Transform the gradient through the optimizer and apply it
            update, opt_state = optim.update( # type: ignore
                update, opt_state, dyn # type: ignore
            )
            dyn = eqx.apply_updates(dyn, update)

            return dyn, opt_state, step, key

        dyn, _, n_iters, _ = lax.while_loop(
            cond_fun=condition,
            body_fun=body,
            init_val=(dyn, opt_state, jnp.array(0, jnp.uint32), key),
        )

        # NOTE(review): called even when ``verbose`` is False, i.e. without a
        # matching ``update_progress`` — confirm ``close_progress`` tolerates that.
        close_progress(loop_id, n_iters)

        return eqx.combine(dyn, static)

The second is the `elbo_grad` method of a specific variational method, `NormalizingFlow`:

    @eqx.filter_jit
    def elbo_grad(self, n: int, batch_size: int, key: PRNGKeyArray) -> Self:
        """
        Estimate the gradient of the ELBO with respect to the dynamic leaves.

        The ELBO is estimated by Monte Carlo. ``n // batch_size`` batches of
        ``batch_size`` draws each are sampled from ``self.base``, and the
        differences ``post - vari`` of the two evaluations returned by
        ``self._eval`` are averaged. If ``n`` is not a multiple of
        ``batch_size``, the remaining ``n % batch_size`` draws are not taken.

        Args:
            n: Total number of draws requested. It must be a concrete Python
                int, because it sizes ``jr.split``.
            batch_size: Number of draws evaluated per batch.
            key: PRNG key for the draws.

        Returns:
            The gradient of the ELBO estimate, structured like the partition of
            ``self`` by ``self.filter_spec``.

        Raises:
            ZeroDivisionError: If ``batch_size == 0``.
            ValueError: If ``n // batch_size < 1``. The mean over zero draws
                would otherwise silently be NaN.
        """
        dyn, static = eqx.partition(self, self.filter_spec)

        # Resolved at trace time: both operands are Python ints
        num_batches = n // batch_size
        if num_batches < 1:
            raise ValueError(
                f"elbo_grad needs at least one full batch, got n={n}, batch_size={batch_size}"
            )

        def elbo(dyn: Self, key: PRNGKeyArray) -> Scalar:
            # Rebuild the full object from the leaves being differentiated
            model = eqx.combine(dyn, static)

            keys = jr.split(key, num_batches)

            def batched_elbo(batch_key: PRNGKeyArray) -> Array:
                # Draw from the base distribution, then evaluate posterior and
                # variational densities on those draws
                draws: Array = model.base.sample(batch_size, key = batch_key)
                post_evals, vari_evals = model._eval(draws)
                return post_evals - vari_evals

            if num_batches == 1:
                # With a single batch there is nothing to loop over. Calling the
                # batch directly avoids tracing a length-1 ``lax.map``, which
                # lowers to a loop. Same key (``keys[0]``), hence the same draws.
                elbo_evals = batched_elbo(keys[0])
            else:
                # Sequential over batches to bound peak memory
                elbo_evals = lax.map(batched_elbo, keys)

            return jnp.mean(elbo_evals)

        # Differentiate with respect to the first argument (``dyn``) only
        elbo_grad_fn: Callable[
            [Self, PRNGKeyArray], Self
        ] = eqx.filter_grad(elbo)

        return elbo_grad_fn(dyn, key)

Note that in my tests `batch_size` is always equal to `n`, so each gradient estimate uses a single batch.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions