diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py
index c530af8d9..c60fcac68 100644
--- a/pymc/sampling/jax.py
+++ b/pymc/sampling/jax.py
@@ -27,7 +27,6 @@
 import pytensor.tensor as pt
 
 from arviz.data.base import make_attrs
-from jax.lax import scan
 from pytensor.compile import SharedVariable, Supervisor, mode
 from pytensor.graph.basic import graph_inputs
 from pytensor.graph.fg import FunctionGraph
@@ -158,56 +157,14 @@ def logp_fn_wrap(x):
     return logp_fn_wrap
 
 
-def _get_log_likelihood(
-    model: Model,
-    samples,
-    backend: Literal["cpu", "gpu"] | None = None,
-    postprocessing_vectorize: Literal["vmap", "scan"] = "scan",
-) -> dict:
-    """Compute log-likelihood for all observations"""
+def _get_log_likelihood(model: Model, samples) -> dict:
+    """Compute log-likelihood for all observations over a batch of draws"""
     elemwise_logp = model.logp(model.observed_RVs, sum=False)
     jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=elemwise_logp)
-    result = _postprocess_samples(
-        jax_fn,
-        samples,
-        backend,
-        postprocessing_vectorize=postprocessing_vectorize,
-        donate_samples=False,
-    )
+    result = jax.vmap(jax_fn)(*samples)
     return {v.name: r for v, r in zip(model.observed_RVs, result)}
 
 
-def _device_put(input, device: str):
-    return jax.device_put(input, jax.devices(device)[0])
-
-
-def _postprocess_samples(
-    jax_fn: Callable,
-    raw_mcmc_samples: list[TensorVariable],
-    postprocessing_backend: Literal["cpu", "gpu"] | None = None,
-    postprocessing_vectorize: Literal["vmap", "scan"] = "vmap",
-    donate_samples: bool = False,
-) -> list[TensorVariable]:
-    if postprocessing_vectorize == "scan":
-        t_raw_mcmc_samples = [jnp.swapaxes(t, 0, 1) for t in raw_mcmc_samples]
-        jax_vfn = jax.vmap(jax_fn)
-        _, outs = scan(
-            lambda _, x: ((), jax_vfn(*x)),
-            (),
-            _device_put(t_raw_mcmc_samples, postprocessing_backend),
-        )
-        return [jnp.swapaxes(t, 0, 1) for t in outs]
-    elif postprocessing_vectorize == "vmap":
-
-        def process_fn(x):
-            return jax.vmap(jax.vmap(jax_fn))(*_device_put(x, postprocessing_backend))
-
-        return jax.jit(process_fn, donate_argnums=0 if donate_samples else None)(raw_mcmc_samples)
-
-    else:
-        raise ValueError(f"Unrecognized postprocessing_vectorize: {postprocessing_vectorize}")
-
-
 def _get_batched_jittered_initial_points(
     model: Model,
     chains: int,
@@ -238,52 +195,37 @@ def _get_batched_jittered_initial_points(
     return [np.stack(init_state) for init_state in zip(*initial_points_values)]
 
 
-def _blackjax_inference_loop(
-    seed, init_position, logprob_fn, draws, tune, target_accept, **adaptation_kwargs
-):
-    import blackjax
+@partial(jax.jit, donate_argnums=0)
+def _set_tree(store, input, idx):
+    """Update pytree of outputs - used for saving results of chunked sampling"""
 
-    from blackjax.adaptation.base import get_filter_adapt_info_fn
+    def update_fn(save, inp):
+        starts = (save.shape[0], idx, *([0] * (len(save.shape) - 2)))
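+        # dynamic_update_slice clamps out-of-range start indices: the chunk
+        # spans the full chain axis, so the first index clamps to 0, while
+        # `idx` positions the chunk along the draw axis.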
+        return jax.lax.dynamic_update_slice(save, inp, starts)
 
-    algorithm_name = adaptation_kwargs.pop("algorithm", "nuts")
-    if algorithm_name == "nuts":
-        algorithm = blackjax.nuts
-    elif algorithm_name == "hmc":
-        algorithm = blackjax.hmc
-    else:
-        raise ValueError("Only supporting 'nuts' or 'hmc' as algorithm to draw samples.")
-
-    adapt = blackjax.window_adaptation(
-        algorithm=algorithm,
-        logdensity_fn=logprob_fn,
-        target_acceptance_rate=target_accept,
-        adaptation_info_fn=get_filter_adapt_info_fn(),
-        **adaptation_kwargs,
-    )
-    (last_state, tuned_params), _ = adapt.run(seed, init_position, num_steps=tune)
-    kernel = algorithm(logprob_fn, **tuned_params).step
+    store = jax.tree.map(update_fn, store, input)
+    return store
 
-    def _one_step(state, xs):
-        _, rng_key = xs
-        state, info = kernel(rng_key, state)
-        position = state.position
-        stats = {
-            "diverging": info.is_divergent,
-            "energy": info.energy,
-            "tree_depth": info.num_trajectory_expansions,
-            "n_steps": info.num_integration_steps,
-            "acceptance_rate": info.acceptance_rate,
-            "lp": state.logdensity,
-        }
-        return state, (position, stats)
 
-    progress_bar = adaptation_kwargs.pop("progress_bar", False)
+def _gen_arr(inp, nchunk):
+    """Generate output array on cpu for chunked sampling"""
+    shape = (inp.shape[0] * nchunk, *inp.shape[1:])
+    return jnp.zeros(shape, dtype=inp.dtype, device=jax.devices("cpu")[0])
 
-    keys = jax.random.split(seed, draws)
-    scan_fn = blackjax.progress_bar.gen_scan_fn(draws, progress_bar)
-    _, (samples, stats) = scan_fn(_one_step, last_state, (jnp.arange(draws), keys))
 
-    return samples, stats
+def _do_chunked_sampling(last_state, output, nchunk, nsteps, sample_fn, progressbar):
+    """Run chunked sampling saving to output on the cpu"""
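+    # Chunk 0 is drawn by the caller (which also provisions `output`); resume
+    # from its final state and fill in chunks 1..nchunk-1 on the cpu.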
+    for i in range(1, nchunk):
+        if progressbar:
+            logger.info("Sampling chunk %d of %d:" % (i + 1, nchunk))
+        last_state, tmpout = sample_fn(last_state)
+        output = _set_tree(
+            output,
+            jax.device_put(tmpout, jax.devices("cpu")[0]),
+            nsteps * i,
+        )
+        del tmpout
+    return last_state, output
 
 
 def _sample_blackjax_nuts(
@@ -296,72 +238,21 @@
     progressbar: bool,
     random_seed: int,
     initial_points,
+    postprocess_fn,
     nuts_kwargs,
+    num_chunks: int = 1,
 ) -> az.InferenceData:
-    """
-    Draw samples from the posterior using the NUTS method from the ``blackjax`` library.
-
-    Parameters
-    ----------
-    draws : int, default 1000
-        The number of samples to draw. The number of tuned samples are discarded by
-        default.
-    tune : int, default 1000
-        Number of iterations to tune. Samplers adjust the step sizes, scalings or
-        similar during tuning. Tuning samples will be drawn in addition to the number
-        specified in the ``draws`` argument.
-    chains : int, default 4
-        The number of chains to sample.
-    target_accept : float in [0, 1].
-        The step size is tuned such that we approximate this acceptance rate. Higher
-        values like 0.9 or 0.95 often work better for problematic posteriors.
-    random_seed : int, RandomState or Generator, optional
-        Random seed used by the sampling steps.
-    initvals: StartDict or Sequence[Optional[StartDict]], optional
-        Initial values for random variables provided as a dictionary (or sequence of
-        dictionaries) mapping the random variable (by name or reference) to desired
-        starting values.
-    jitter: bool, default True
-        If True, add jitter to initial points.
-    model : Model, optional
-        Model to sample from. The model needs to have free random variables. When inside
-        a ``with`` model context, it defaults to that model, otherwise the model must be
-        passed explicitly.
-    var_names : sequence of str, optional
-        Names of variables for which to compute the posterior samples. Defaults to all
-        variables in the posterior.
-    keep_untransformed : bool, default False
-        Include untransformed variables in the posterior samples. Defaults to False.
-    chain_method : str, default "parallel"
-        Specify how samples should be drawn. The choices include "parallel", and
-        "vectorized".
-    postprocessing_backend: Optional[Literal["cpu", "gpu"]], default None,
-        Specify how postprocessing should be computed. gpu or cpu
-    postprocessing_vectorize: Literal["vmap", "scan"], default "scan"
-        How to vectorize the postprocessing: vmap or sequential scan
-    idata_kwargs : dict, optional
-        Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as
-        value for the ``log_likelihood`` key to indicate that the pointwise log
-        likelihood should not be included in the returned object. Values for
-        ``observed_data``, ``constant_data``, ``coords``, and ``dims`` are inferred from
-        the ``model`` argument if not provided in ``idata_kwargs``. If ``coords`` and
-        ``dims`` are provided, they are used to update the inferred dictionaries.
-
-    Returns
-    -------
-    InferenceData
-        ArviZ ``InferenceData`` object that contains the posterior samples, together
-        with their respective sample stats and pointwise log likeihood values (unless
-        skipped with ``idata_kwargs``).
-    """
     import blackjax
+    from blackjax.adaptation.base import get_filter_adapt_info_fn
 
+    # Adapted from numpyro
     if chain_method == "parallel":
         map_fn = jax.pmap
     elif chain_method == "vectorized":
-        map_fn = jax.vmap
+
+        def map_fn(x):
+            return jax.jit(jax.vmap(x))
+
     else:
         raise ValueError(
             "Only supporting the following methods to draw chains:" ' "parallel" or "vectorized"'
         )
@@ -372,41 +263,108 @@ def _sample_blackjax_nuts(
 
     logprob_fn = get_jaxified_logp(model)
 
-    seed = jax.random.PRNGKey(random_seed)
-    keys = jax.random.split(seed, chains)
+    s1, s2 = jax.random.split(jax.random.PRNGKey(random_seed))
+    adapt_seed = jax.random.split(s1, chains)
+    sample_seed = jax.random.split(s2, chains)
 
-    nuts_kwargs["progress_bar"] = progressbar
-    get_posterior_samples = partial(
-        _blackjax_inference_loop,
-        logprob_fn=logprob_fn,
-        tune=tune,
-        draws=draws,
-        target_accept=target_accept,
-        **nuts_kwargs,
+    algorithm_name = nuts_kwargs.pop("algorithm", "nuts")
+    if algorithm_name == "nuts":
+        algorithm = blackjax.nuts
+    elif algorithm_name == "hmc":
+        algorithm = blackjax.hmc
+    else:
+        raise ValueError("Only supporting 'nuts' or 'hmc' as algorithm to draw samples.")
+
+    assert draws % num_chunks == 0
+    nsteps = draws // num_chunks
+
+    # Run adaptation for sampling parameters
+    @map_fn
+    def run_adaptation(seed, init_position):
+        return blackjax.window_adaptation(
+            algorithm=algorithm,
+            logdensity_fn=logprob_fn,
+            target_acceptance_rate=target_accept,
+            adaptation_info_fn=get_filter_adapt_info_fn(),
+            progress_bar=progressbar,
+            **nuts_kwargs,
+        ).run(seed, init_position, num_steps=tune)
+
+    (adapt_state, tuned_params), _ = run_adaptation(adapt_seed, initial_points)
+
+    # Filters output from each sampling step
+    def _transform_fn(state, info):
+        position = state.position
+        stats = {
+            "diverging": info.is_divergent,
+            "energy": info.energy,
+            "tree_depth": info.num_trajectory_expansions,
+            "n_steps": info.num_integration_steps,
+            "acceptance_rate": info.acceptance_rate,
+            "lp": state.logdensity,
+        }
+        return position, stats
+
+    # Performs sampling for each chunk
+    # random keys are carried with state
+    @map_fn
+    @partial(jax.jit, donate_argnums=0)
+    def _multi_step(state, imm, ss):
+        state, key = state
+        key, _skey = jax.random.split(key)
+        last_state, (raw_samples, stats) = blackjax.util.run_inference_algorithm(
+            _skey,
+            algorithm(logprob_fn, inverse_mass_matrix=imm, step_size=ss),
+            num_steps=nsteps,
+            initial_state=state,
+            progress_bar=progressbar,
+            transform=_transform_fn,
+        )
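+        # Postprocess inside the jitted step so only transformed draws (and
+        # optional pointwise log-likelihoods) for this chunk are materialized;
+        # the raw draws never have to be stored for the whole run.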
+        samples, log_likelihoods = postprocess_fn(raw_samples)
+        return (last_state, key), ((samples, log_likelihoods), stats)
+
+    chunk_sample_fn = partial(
+        _multi_step, imm=tuned_params["inverse_mass_matrix"], ss=tuned_params["step_size"]
     )
-    raw_mcmc_samples, sample_stats = map_fn(get_posterior_samples)(keys, initial_points)
 
-    return raw_mcmc_samples, sample_stats, blackjax
+    if progressbar:
+        logger.info("Sampling chunk %d of %d:" % (1, num_chunks))
+
+    # Sample first chunk
+    last_state, sample_data = chunk_sample_fn((adapt_state, sample_seed))
+
+    # If single chunk sampling return results on device
+    if num_chunks == 1:
+        ((samples, log_likelihoods), stats) = sample_data
+        return samples, stats, log_likelihoods, blackjax
+
+    # Provision space for all samples on the cpu + save first chunk
+    output = _set_tree(
+        jax.tree.map(jax.vmap(partial(_gen_arr, nchunk=num_chunks)), sample_data),
+        jax.device_put(sample_data, jax.devices("cpu")[0]),
+        0,
+    )
+    del sample_data
+
+    # Sample remaining chunks
+    _, ((samples, log_likelihoods), stats) = _do_chunked_sampling(
+        last_state, output, num_chunks, nsteps, chunk_sample_fn, progressbar
+    )
+    return samples, stats, log_likelihoods, blackjax
 
 
-# Adopted from arviz numpyro extractor
 def _numpyro_stats_to_dict(posterior):
     """Extract sample_stats from NumPyro posterior."""
-    rename_key = {
-        "potential_energy": "lp",
-        "adapt_state.step_size": "step_size",
-        "num_steps": "n_steps",
-        "accept_prob": "acceptance_rate",
+    extra_fields = posterior.get_extra_fields(group_by_chain=True)
+    data = {
+        "lp": extra_fields["potential_energy"],
+        "step_size": extra_fields["adapt_state.step_size"],
+        "n_steps": extra_fields["num_steps"],
+        "acceptance_rate": extra_fields["accept_prob"],
+        "tree_depth": jnp.log2(extra_fields["num_steps"]).astype(int) + 1,
+        "energy": extra_fields["energy"],
+        "diverging": extra_fields["diverging"],
     }
-    data = {}
-    for stat, value in posterior.get_extra_fields(group_by_chain=True).items():
-        if isinstance(value, dict | tuple):
-            continue
-        name = rename_key.get(stat, stat)
-        value = value.copy()
-        data[name] = value
-        if stat == "num_steps":
-            data["tree_depth"] = np.log2(value).astype(int) + 1
     return data
 
 
@@ -420,12 +378,17 @@ def _sample_numpyro_nuts(
     progressbar: bool,
     random_seed: int,
     initial_points,
+    postprocess_fn,
     nuts_kwargs: dict[str, Any],
+    num_chunks: int = 1,
 ):
     import numpyro
 
     from numpyro.infer import MCMC, NUTS
 
+    assert draws % num_chunks == 0
+    nsteps = draws // num_chunks
+
     logp_fn = get_jaxified_logp(model, negative_logp=False)
 
     nuts_kwargs.setdefault("adapt_step_size", True)
@@ -441,33 +404,61 @@ def _sample_numpyro_nuts(
     pmap_numpyro = MCMC(
         nuts_kernel,
         num_warmup=tune,
-        num_samples=draws,
+        num_samples=nsteps,
         num_chains=chains,
         postprocess_fn=None,
         chain_method=chain_method,
         progress_bar=progressbar,
     )
 
-    map_seed = jax.random.PRNGKey(random_seed)
-    if chains > 1:
-        map_seed = jax.random.split(map_seed, chains)
-
-    pmap_numpyro.run(
-        map_seed,
-        init_params=initial_points,
-        extra_fields=(
-            "num_steps",
-            "potential_energy",
-            "energy",
-            "adapt_state.step_size",
-            "accept_prob",
-            "diverging",
-        ),
+    extra_fields = (
+        "num_steps",
+        "potential_energy",
+        "energy",
+        "adapt_state.step_size",
+        "accept_prob",
+        "diverging",
     )
+
+    vmap_postprocess = jax.jit(jax.vmap(postprocess_fn))
+
+    key = jax.random.PRNGKey(random_seed)
+    if progressbar:
+        logger.info("Sampling chunk %d of %d:" % (1, num_chunks))
+    pmap_numpyro.run(key, init_params=initial_points, extra_fields=extra_fields)
     raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
-    sample_stats = _numpyro_stats_to_dict(pmap_numpyro)
-    return raw_mcmc_samples, sample_stats, numpyro
+    stats = _numpyro_stats_to_dict(pmap_numpyro)
+    samples = vmap_postprocess(raw_mcmc_samples)
+
+    if num_chunks == 1:
+        return samples[0], stats, samples[1], numpyro
+
+    def sample_chunk(state):
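+        # Resume the warmed-up MCMC run: assigning `post_warmup_state` and
+        # re-running with its rng_key continues sampling without re-tuning
+        # (the pattern documented for NumPyro's MCMC).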
+        pmap_numpyro.post_warmup_state = state
+        pmap_numpyro.run(pmap_numpyro.post_warmup_state.rng_key, extra_fields=extra_fields)
+
+        raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
+        sample_stats = _numpyro_stats_to_dict(pmap_numpyro)
+        mcmc_samples, likelihoods = vmap_postprocess(raw_mcmc_samples)
+        return pmap_numpyro.last_state, ((mcmc_samples, likelihoods), sample_stats)
+
+    output = _set_tree(
+        jax.tree.map(jax.vmap(partial(_gen_arr, nchunk=num_chunks)), (samples, stats)),
+        jax.device_put((samples, stats), jax.devices("cpu")[0]),
+        0,
+    )
+    del samples, stats
+
+    _, (all_samples, all_stats) = _do_chunked_sampling(
+        pmap_numpyro.last_state,
+        output,
+        num_chunks,
+        nsteps,
+        sample_chunk,
+        progressbar,
+    )
+
+    return all_samples[0], all_stats, all_samples[1], numpyro
 
 
 def sample_jax_nuts(
@@ -485,12 +476,13 @@
     progressbar: bool = True,
    keep_untransformed: bool = False,
     chain_method: str = "parallel",
-    postprocessing_backend: Literal["cpu", "gpu"] | None = None,
-    postprocessing_vectorize: Literal["vmap", "scan"] | None = None,
-    postprocessing_chunks=None,
+    postprocessing_backend: Literal["cpu", "gpu"] | None = None,  # Note unused
+    postprocessing_vectorize: Literal["vmap", "scan"] | None = None,  # Note unused
+    postprocessing_chunks=None,  # Note unused
     idata_kwargs: dict | None = None,
     compute_convergence_checks: bool = True,
     nuts_sampler: Literal["numpyro", "blackjax"],
+    num_chunks: int = 1,
 ) -> az.InferenceData:
     """
     Draw samples from the posterior using a jax NUTS method.
@@ -551,6 +543,9 @@ def sample_jax_nuts(
     nuts_sampler : Literal["numpyro", "blackjax"]
         Nuts sampler library to use - do not change - use sample_numpyro_nuts or
         sample_blackjax_nuts as appropriate
+    num_chunks : int
+        Split sampling into `num_chunks` chunks whose results are collected on the cpu.
+        This reduces gpu memory usage; there is no benefit when sampling on the cpu.
 
     Returns
     -------
@@ -568,6 +563,16 @@ def sample_jax_nuts(
             DeprecationWarning,
         )
 
+    if postprocessing_backend is not None:
+        import warnings
+
+        warnings.warn(
+            "postprocessing_backend={'cpu', 'gpu'} will be removed in a future release; "
+            "postprocessing will be done on the sampling device in the future. If device "
+            "memory consumption is an issue, please use num_chunks to reduce consumption.",
+            DeprecationWarning,
+        )
+
     if postprocessing_vectorize is not None:
         import warnings
 
@@ -585,15 +590,24 @@ def sample_jax_nuts(
     else:
         filtered_var_names = model.unobserved_value_vars
 
-    if nuts_kwargs is None:
-        nuts_kwargs = {}
-    else:
-        nuts_kwargs = nuts_kwargs.copy()
+    nuts_kwargs = {} if nuts_kwargs is None else nuts_kwargs.copy()
+    idata_kwargs = {} if idata_kwargs is None else idata_kwargs.copy()
 
     vars_to_sample = list(
         get_default_varnames(filtered_var_names, include_transformed=keep_untransformed)
     )
 
+    log_likelihood_fn = partial(_get_log_likelihood, model)
+    transform_jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
+
+    def postprocess_fn(samples):
+        mcmc_samples, likelihoods = None, None
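+        # Note: `pop` runs at trace time; it reads the flag once and strips it
+        # from idata_kwargs so the boolean is not forwarded to arviz later.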
+        if idata_kwargs.pop("log_likelihood", False):
+            likelihoods = log_likelihood_fn(samples)
+        result = jax.vmap(transform_jax_fn)(*samples)
+        mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
+        return mcmc_samples, likelihoods
+
     (random_seed,) = _get_seeds_per_chain(random_seed, 1)
 
     initial_points = _get_batched_jittered_initial_points(
@@ -611,8 +625,16 @@ def sample_jax_nuts(
     else:
         raise ValueError(f"{nuts_sampler=} not recognized")
 
+    current_backend = jax.default_backend()
+    if postprocessing_backend is not None and current_backend != postprocessing_backend:
+
+        def process_fn(x):
+            return x, None
+    else:
+        process_fn = postprocess_fn
+
     tic1 = datetime.now()
-    raw_mcmc_samples, sample_stats, library = sampler_fn(
+    mcmc_samples, sample_stats, log_likelihood, library = sampler_fn(
         model=model,
         target_accept=target_accept,
         tune=tune,
@@ -622,35 +644,19 @@ def sample_jax_nuts(
         progressbar=progressbar,
         random_seed=random_seed,
         initial_points=initial_points,
+        postprocess_fn=process_fn,
         nuts_kwargs=nuts_kwargs,
+        num_chunks=num_chunks,
     )
-    tic2 = datetime.now()
 
-    if idata_kwargs is None:
-        idata_kwargs = {}
-    else:
-        idata_kwargs = idata_kwargs.copy()
-
-    if idata_kwargs.pop("log_likelihood", False):
-        log_likelihood = _get_log_likelihood(
-            model,
-            raw_mcmc_samples,
-            backend=postprocessing_backend,
-            postprocessing_vectorize=postprocessing_vectorize,
+    if postprocessing_backend is not None and current_backend != postprocessing_backend:
+        mcmc_samples = jax.device_put(mcmc_samples, jax.devices(postprocessing_backend)[0])
+        sample_stats = jax.device_put(sample_stats, jax.devices(postprocessing_backend)[0])
+        mcmc_samples, log_likelihood = jax.jit(jax.vmap(postprocess_fn), donate_argnums=0)(
+            mcmc_samples
         )
-    else:
-        log_likelihood = None
-
-    jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
-    result = _postprocess_samples(
-        jax_fn,
-        raw_mcmc_samples,
-        postprocessing_backend=postprocessing_backend,
-        postprocessing_vectorize=postprocessing_vectorize,
-        donate_samples=True,
-    )
-    del raw_mcmc_samples
-    mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
+
+    tic2 = datetime.now()
 
     attrs = {
         "sampling_time": (tic2 - tic1).total_seconds(),
diff --git a/tests/sampling/test_jax.py b/tests/sampling/test_jax.py
index 5f32e1075..30000f859 100644
--- a/tests/sampling/test_jax.py
+++ b/tests/sampling/test_jax.py
@@ -229,7 +229,7 @@ def test_get_log_likelihood():
     b_true = trace.log_likelihood.b.values
     a = np.array(trace.posterior.a)
     sigma_log_ = np.log(np.array(trace.posterior.sigma))
-    b_jax = _get_log_likelihood(model, [a, sigma_log_])["b"]
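+    # The outer vmap maps over the chain axis (`model` passes through unmapped
+    # via in_axes=(None, 0)); _get_log_likelihood vmaps over draws internally.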
+    b_jax = jax.vmap(_get_log_likelihood, in_axes=(None, 0))(model, [a, sigma_log_])["b"]
 
     assert np.allclose(b_jax.reshape(-1), b_true.reshape(-1))
diff --git a/tests/sampling/test_mcmc_external.py b/tests/sampling/test_mcmc_external.py
index 4c594a2b6..364d8e0f8 100644
--- a/tests/sampling/test_mcmc_external.py
+++ b/tests/sampling/test_mcmc_external.py
@@ -13,7 +13,6 @@
 #    limitations under the License.
 
 import numpy as np
-import numpy.testing as npt
 import pytest
 
 from pymc import Data, Model, Normal, sample
@@ -74,14 +73,32 @@ def test_external_nuts_sampler(recwarn, nuts_sampler):
     assert idata_reference.posterior.attrs.keys() == idata1.posterior.attrs.keys()
 
 
-def test_step_args():
-    with Model() as model:
-        a = Normal("a")
-        idata = sample(
-            nuts_sampler="numpyro",
-            target_accept=0.5,
-            nuts={"max_treedepth": 10},
-            random_seed=1410,
-        )
+def test_numpyro_external_nuts_chunking():
+    # chunked sampling should give exactly the same results as non-chunked
+    nuts_sampler = "numpyro"
+    pytest.importorskip(nuts_sampler)
+
+    with Model():
+        x = Normal("x", 100, 5)
+        y = Data("y", [1, 2, 3, 4])
+
+        Normal("L", mu=x, sigma=0.1, observed=y)
+
+    base_kwargs = dict(
+        nuts_sampler=nuts_sampler,
+        random_seed=123,
+        chains=2,
+        tune=500,
+        draws=500,
+        progressbar=False,
+        initvals={"x": 0.0},
+        idata_kwargs={"log_likelihood": True},
+    )
+    chunk_kwargs = {**base_kwargs, **{"nuts_sampler_kwargs": {"num_chunks": 10}}}
 
-    npt.assert_almost_equal(idata.sample_stats.acceptance_rate.mean(), 0.5, decimal=1)
+    idata1 = sample(**base_kwargs)
+    idata2 = sample(**chunk_kwargs)
+
+    np.testing.assert_array_equal(idata1.posterior.x, idata2.posterior.x)
+    np.testing.assert_array_equal(idata1.log_likelihood.L, idata2.log_likelihood.L)
+    assert idata1.posterior.attrs.keys() == idata2.posterior.attrs.keys()