From ebe4dee54bf3ccf91ea20d6e2a5283d0bcb8f8f8 Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Mon, 12 Aug 2024 09:58:14 -0700 Subject: [PATCH 01/17] chunking jax samplers --- pymc/sampling/jax.py | 403 ++++++++++++++++++++----------------------- 1 file changed, 186 insertions(+), 217 deletions(-) diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index c530af8d9..6deadc7f9 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -27,7 +27,6 @@ import pytensor.tensor as pt from arviz.data.base import make_attrs -from jax.lax import scan from pytensor.compile import SharedVariable, Supervisor, mode from pytensor.graph.basic import graph_inputs from pytensor.graph.fg import FunctionGraph @@ -158,54 +157,14 @@ def logp_fn_wrap(x): return logp_fn_wrap -def _get_log_likelihood( - model: Model, - samples, - backend: Literal["cpu", "gpu"] | None = None, - postprocessing_vectorize: Literal["vmap", "scan"] = "scan", -) -> dict: +def _get_log_likelihood_fn(model: Model) -> Callable: """Compute log-likelihood for all observations""" elemwise_logp = model.logp(model.observed_RVs, sum=False) jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=elemwise_logp) - result = _postprocess_samples( - jax_fn, - samples, - backend, - postprocessing_vectorize=postprocessing_vectorize, - donate_samples=False, - ) - return {v.name: r for v, r in zip(model.observed_RVs, result)} - - -def _device_put(input, device: str): - return jax.device_put(input, jax.devices(device)[0]) - - -def _postprocess_samples( - jax_fn: Callable, - raw_mcmc_samples: list[TensorVariable], - postprocessing_backend: Literal["cpu", "gpu"] | None = None, - postprocessing_vectorize: Literal["vmap", "scan"] = "vmap", - donate_samples: bool = False, -) -> list[TensorVariable]: - if postprocessing_vectorize == "scan": - t_raw_mcmc_samples = [jnp.swapaxes(t, 0, 1) for t in raw_mcmc_samples] - jax_vfn = jax.vmap(jax_fn) - _, outs = scan( - lambda _, x: ((), jax_vfn(*x)), - (), - _device_put(t_raw_mcmc_samples, postprocessing_backend), - ) - return [jnp.swapaxes(t, 0, 1) for t in outs] - elif postprocessing_vectorize == "vmap": - - def process_fn(x): - return jax.vmap(jax.vmap(jax_fn))(*_device_put(x, postprocessing_backend)) - - return jax.jit(process_fn, donate_argnums=0 if donate_samples else None)(raw_mcmc_samples) - - else: - raise ValueError(f"Unrecognized postprocessing_vectorize: {postprocessing_vectorize}") + def log_likelihood_fn(samples): + result = jax.vmap(jax_fn)(*samples) + return {v.name: r for v, r in zip(model.observed_RVs, result)} + return log_likelihood_fn def _get_batched_jittered_initial_points( @@ -238,52 +197,38 @@ def _get_batched_jittered_initial_points( return [np.stack(init_state) for init_state in zip(*initial_points_values)] -def _blackjax_inference_loop( - seed, init_position, logprob_fn, draws, tune, target_accept, **adaptation_kwargs -): - import blackjax +@partial(jax.jit, donate_argnums=0) +def _set_tree(store, input, idx): + def update_fn(save, inp): + starts = (idx, *([0] * (len(save.shape) - 1))) + return jax.lax.dynamic_update_slice(save, inp, starts) - from blackjax.adaptation.base import get_filter_adapt_info_fn + store = jax.tree.map(update_fn, store, input) + return store - algorithm_name = adaptation_kwargs.pop("algorithm", "nuts") - if algorithm_name == "nuts": - algorithm = blackjax.nuts - elif algorithm_name == "hmc": - algorithm = blackjax.hmc - else: - raise ValueError("Only supporting 'nuts' or 'hmc' as algorithm to draw samples.") - adapt = blackjax.window_adaptation( - 
algorithm=algorithm, - logdensity_fn=logprob_fn, - target_acceptance_rate=target_accept, - adaptation_info_fn=get_filter_adapt_info_fn(), - **adaptation_kwargs, - ) - (last_state, tuned_params), _ = adapt.run(seed, init_position, num_steps=tune) - kernel = algorithm(logprob_fn, **tuned_params).step +def _gen_arr(inp, nchunk): + shape = (inp.shape[0] * nchunk, *inp.shape[1:]) + return jnp.zeros(shape, dtype=inp.dtype, device=jax.devices("cpu")[0]) - def _one_step(state, xs): - _, rng_key = xs - state, info = kernel(rng_key, state) - position = state.position - stats = { - "diverging": info.is_divergent, - "energy": info.energy, - "tree_depth": info.num_trajectory_expansions, - "n_steps": info.num_integration_steps, - "acceptance_rate": info.acceptance_rate, - "lp": state.logdensity, - } - return state, (position, stats) - progress_bar = adaptation_kwargs.pop("progress_bar", False) - - keys = jax.random.split(seed, draws) - scan_fn = blackjax.progress_bar.gen_scan_fn(draws, progress_bar) - _, (samples, stats) = scan_fn(_one_step, last_state, (jnp.arange(draws), keys)) +def _do_chunked_sampling(last_state, tmpout, nchunk, nsteps, sample_fn, progressbar): + output = _set_tree( + jax.tree.map(jax.vmap(partial(_gen_arr, nchunk=nchunk)), tmpout), + jax.device_put(tmpout, jax.devices("cpu")[0]), + 0 + ) - return samples, stats + for i in range(1, nchunk): + if progressbar: + print("Sampling chunk %d of %d:" % (i+1, nchunk)) + last_state, tmpout = sample_fn(last_state) + output = _set_tree( + output, + jax.device_put(tmpout, jax.devices("cpu")[0]), + nsteps * i, + ) + return last_state, output def _sample_blackjax_nuts( @@ -296,72 +241,19 @@ def _sample_blackjax_nuts( progressbar: bool, random_seed: int, initial_points, + postprocess_fn, nuts_kwargs, + num_chunks: int = 1, ) -> az.InferenceData: - """ - Draw samples from the posterior using the NUTS method from the ``blackjax`` library. - - Parameters - ---------- - draws : int, default 1000 - The number of samples to draw. The number of tuned samples are discarded by - default. - tune : int, default 1000 - Number of iterations to tune. Samplers adjust the step sizes, scalings or - similar during tuning. Tuning samples will be drawn in addition to the number - specified in the ``draws`` argument. - chains : int, default 4 - The number of chains to sample. - target_accept : float in [0, 1]. - The step size is tuned such that we approximate this acceptance rate. Higher - values like 0.9 or 0.95 often work better for problematic posteriors. - random_seed : int, RandomState or Generator, optional - Random seed used by the sampling steps. - initvals: StartDict or Sequence[Optional[StartDict]], optional - Initial values for random variables provided as a dictionary (or sequence of - dictionaries) mapping the random variable (by name or reference) to desired - starting values. - jitter: bool, default True - If True, add jitter to initial points. - model : Model, optional - Model to sample from. The model needs to have free random variables. When inside - a ``with`` model context, it defaults to that model, otherwise the model must be - passed explicitly. - var_names : sequence of str, optional - Names of variables for which to compute the posterior samples. Defaults to all - variables in the posterior. - keep_untransformed : bool, default False - Include untransformed variables in the posterior samples. Defaults to False. - chain_method : str, default "parallel" - Specify how samples should be drawn. The choices include "parallel", and - "vectorized". 
- postprocessing_backend: Optional[Literal["cpu", "gpu"]], default None, - Specify how postprocessing should be computed. gpu or cpu - postprocessing_vectorize: Literal["vmap", "scan"], default "scan" - How to vectorize the postprocessing: vmap or sequential scan - idata_kwargs : dict, optional - Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as - value for the ``log_likelihood`` key to indicate that the pointwise log - likelihood should not be included in the returned object. Values for - ``observed_data``, ``constant_data``, ``coords``, and ``dims`` are inferred from - the ``model`` argument if not provided in ``idata_kwargs``. If ``coords`` and - ``dims`` are provided, they are used to update the inferred dictionaries. - - Returns - ------- - InferenceData - ArviZ ``InferenceData`` object that contains the posterior samples, together - with their respective sample stats and pointwise log likeihood values (unless - skipped with ``idata_kwargs``). - """ import blackjax + from blackjax.adaptation.base import get_filter_adapt_info_fn # Adapted from numpyro if chain_method == "parallel": map_fn = jax.pmap elif chain_method == "vectorized": - map_fn = jax.vmap + map_fn = lambda x: jax.jit(jax.vmap(x)) else: raise ValueError( "Only supporting the following methods to draw chains:" ' "parallel" or "vectorized"' @@ -372,41 +264,96 @@ def _sample_blackjax_nuts( logprob_fn = get_jaxified_logp(model) - seed = jax.random.PRNGKey(random_seed) - keys = jax.random.split(seed, chains) + s1, s2 = jax.random.split(jax.random.PRNGKey(random_seed)) + adapt_seed = jax.random.split(s1, chains) + sample_seed = jax.random.split(s2, chains) + del s1, s2 - nuts_kwargs["progress_bar"] = progressbar - get_posterior_samples = partial( - _blackjax_inference_loop, - logprob_fn=logprob_fn, - tune=tune, - draws=draws, - target_accept=target_accept, + + algorithm_name = nuts_kwargs.pop("algorithm", "nuts") + if algorithm_name == "nuts": + algorithm = blackjax.nuts + elif algorithm_name == "hmc": + algorithm = blackjax.hmc + else: + raise ValueError("Only supporting 'nuts' or 'hmc' as algorithm to draw samples.") + + assert draws % num_chunks == 0 + nsteps = draws // num_chunks + + # Run adaptation + adapt = blackjax.window_adaptation( + algorithm=algorithm, + logdensity_fn=logprob_fn, + target_acceptance_rate=target_accept, + adaptation_info_fn=get_filter_adapt_info_fn(), + progress_bar=progressbar, **nuts_kwargs, ) - raw_mcmc_samples, sample_stats = map_fn(get_posterior_samples)(keys, initial_points) - return raw_mcmc_samples, sample_stats, blackjax + @map_fn + def run_adaptation(seed, init_position): + return adapt.run(seed, init_position, num_steps=tune) + (last_state, tuned_params), _ = run_adaptation(adapt_seed, initial_points) + del adapt_seed + + def _one_step(state, x, imm, ss): + _, rng_key = x + state, info = algorithm(logprob_fn, inverse_mass_matrix=imm, step_size=ss).step( + rng_key, state + ) + position = state.position + stats = { + "diverging": info.is_divergent, + "energy": info.energy, + "tree_depth": info.num_trajectory_expansions, + "n_steps": info.num_integration_steps, + "acceptance_rate": info.acceptance_rate, + "lp": state.logdensity, + } + return state, (position, stats) + + @map_fn + def _multi_step(state, imm, ss): + start_state, key = state + key, _skey = jax.random.split(key) + _skeys = jax.random.split(_skey, nsteps) + + scan_fn = blackjax.progress_bar.gen_scan_fn(nsteps, progressbar) + + last_state, (raw_samples, stats) = scan_fn( + partial(_one_step, imm=imm, 
ss=ss), start_state, (jnp.arange(nsteps), _skeys) + ) + samples, log_likelihoods = postprocess_fn(raw_samples) + return (last_state, key), ((samples, log_likelihoods), stats) + + sample_fn = partial(_multi_step, imm=tuned_params["inverse_mass_matrix"], ss=tuned_params["step_size"]) + + if progressbar and num_chunks > 1: + print("Sampling chunk %d of %d:" % (1, num_chunks)) + elif progressbar: + print("Sampling:") + (last_state, seed), (samples, stats) = sample_fn((last_state, sample_seed)) + del sample_seed + if num_chunks == 1: + return samples[0], stats, samples[1], blackjax + + last_state, (all_samples, all_stats) = _do_chunked_sampling((last_state, seed), (samples, stats), num_chunks, nsteps, sample_fn, progressbar) + return all_samples[0], all_stats, all_samples[1], blackjax -# Adopted from arviz numpyro extractor def _numpyro_stats_to_dict(posterior): """Extract sample_stats from NumPyro posterior.""" - rename_key = { - "potential_energy": "lp", - "adapt_state.step_size": "step_size", - "num_steps": "n_steps", - "accept_prob": "acceptance_rate", + extra_fields = posterior.get_extra_fields(group_by_chain=True) + data = { + "lp": extra_fields["potential_energy"], + "step_size": extra_fields["adapt_state.step_size"], + "n_steps": extra_fields["num_steps"], + "acceptance_rate": extra_fields["accept_prob"], + "tree_depth": jnp.log2(extra_fields["num_steps"]).astype(int) + 1, + "energy": extra_fields["energy"], + "diverging": extra_fields["diverging"], } - data = {} - for stat, value in posterior.get_extra_fields(group_by_chain=True).items(): - if isinstance(value, dict | tuple): - continue - name = rename_key.get(stat, stat) - value = value.copy() - data[name] = value - if stat == "num_steps": - data["tree_depth"] = np.log2(value).astype(int) + 1 return data @@ -420,12 +367,17 @@ def _sample_numpyro_nuts( progressbar: bool, random_seed: int, initial_points, + postprocess_fn, nuts_kwargs: dict[str, Any], + num_chunks: int = 1, ): - import numpyro + import numpyro from numpyro.infer import MCMC, NUTS + assert draws % num_chunks == 0 + nsteps = draws // num_chunks + logp_fn = get_jaxified_logp(model, negative_logp=False) nuts_kwargs.setdefault("adapt_step_size", True) @@ -441,33 +393,55 @@ def _sample_numpyro_nuts( pmap_numpyro = MCMC( nuts_kernel, num_warmup=tune, - num_samples=draws, + num_samples=nsteps, num_chains=chains, postprocess_fn=None, chain_method=chain_method, progress_bar=progressbar, ) - map_seed = jax.random.PRNGKey(random_seed) - if chains > 1: - map_seed = jax.random.split(map_seed, chains) - - pmap_numpyro.run( - map_seed, - init_params=initial_points, - extra_fields=( + extra_fields = ( "num_steps", "potential_energy", "energy", "adapt_state.step_size", "accept_prob", "diverging", - ), - ) + ) + vmap_postprocess = jax.jit(jax.vmap(postprocess_fn)) + + key = jax.random.PRNGKey(random_seed) + del random_seed + key, _skey = jax.random.split(key) + if progressbar and num_chunks > 1: + print("Sampling chunk %d of %d:" % (1, num_chunks)) + elif progressbar: + print("Sampling:") + pmap_numpyro.run(_skey, init_params=initial_points, extra_fields=extra_fields) + del _skey raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True) - sample_stats = _numpyro_stats_to_dict(pmap_numpyro) - return raw_mcmc_samples, sample_stats, numpyro + stats = _numpyro_stats_to_dict(pmap_numpyro) + samples = vmap_postprocess(raw_mcmc_samples) + + if num_chunks == 1: + return samples[0], stats, samples[1], numpyro + + def sample_chunk(state): + pmap_numpyro.post_warmup_state, key = state + key, 
_skey = jax.random.split(key) + pmap_numpyro.run(_skey, extra_fields=extra_fields) + del _skey + + raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True) + sample_stats = _numpyro_stats_to_dict(pmap_numpyro) + mcmc_samples, likelihoods = vmap_postprocess(raw_mcmc_samples) + return (pmap_numpyro.last_state, key), ((mcmc_samples, likelihoods), sample_stats) + + + _, (all_samples, all_stats) = _do_chunked_sampling((pmap_numpyro.last_state, key), (samples, stats), num_chunks, nsteps, sample_chunk, progressbar) + + return all_samples[0], all_stats, all_samples[1], numpyro def sample_jax_nuts( @@ -485,12 +459,13 @@ def sample_jax_nuts( progressbar: bool = True, keep_untransformed: bool = False, chain_method: str = "parallel", - postprocessing_backend: Literal["cpu", "gpu"] | None = None, - postprocessing_vectorize: Literal["vmap", "scan"] | None = None, - postprocessing_chunks=None, + postprocessing_backend: Literal["cpu", "gpu"] | None = None, #Note unused + postprocessing_vectorize: Literal["vmap", "scan"] | None = None, #Note unused + postprocessing_chunks=None, #Note unused idata_kwargs: dict | None = None, compute_convergence_checks: bool = True, nuts_sampler: Literal["numpyro", "blackjax"], + num_chunks: int = 1, ) -> az.InferenceData: """ Draw samples from the posterior using a jax NUTS method. @@ -568,6 +543,15 @@ def sample_jax_nuts( DeprecationWarning, ) + if postprocessing_backend is not None: + import warnings + + warnings.warn( + "postprocessing_backend={'cpu', 'gpu'} will be removed in a future release, " + "postprocessing is done on sampling device.", + DeprecationWarning, + ) + if postprocessing_vectorize is not None: import warnings @@ -585,15 +569,24 @@ def sample_jax_nuts( else: filtered_var_names = model.unobserved_value_vars - if nuts_kwargs is None: - nuts_kwargs = {} - else: - nuts_kwargs = nuts_kwargs.copy() + nuts_kwargs = {} if nuts_kwargs is None else nuts_kwargs.copy() + idata_kwargs = {} if idata_kwargs is None else idata_kwargs.copy() vars_to_sample = list( get_default_varnames(filtered_var_names, include_transformed=keep_untransformed) ) + log_likelihood_fn = _get_log_likelihood_fn(model) + transform_jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample) + + def postprocess_fn(samples): + mcmc_samples, likelihoods = None, None + if idata_kwargs.pop("log_likelihood", False): + likelihoods = log_likelihood_fn(samples) + result = jax.vmap(transform_jax_fn)(*samples) + mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)} + return mcmc_samples, likelihoods + (random_seed,) = _get_seeds_per_chain(random_seed, 1) initial_points = _get_batched_jittered_initial_points( @@ -612,7 +605,7 @@ def sample_jax_nuts( raise ValueError(f"{nuts_sampler=} not recognized") tic1 = datetime.now() - raw_mcmc_samples, sample_stats, library = sampler_fn( + mcmc_samples, sample_stats, log_likelihood, library = sampler_fn( model=model, target_accept=target_accept, tune=tune, @@ -622,36 +615,12 @@ def sample_jax_nuts( progressbar=progressbar, random_seed=random_seed, initial_points=initial_points, + postprocess_fn=postprocess_fn, nuts_kwargs=nuts_kwargs, + num_chunks=num_chunks, ) tic2 = datetime.now() - if idata_kwargs is None: - idata_kwargs = {} - else: - idata_kwargs = idata_kwargs.copy() - - if idata_kwargs.pop("log_likelihood", False): - log_likelihood = _get_log_likelihood( - model, - raw_mcmc_samples, - backend=postprocessing_backend, - postprocessing_vectorize=postprocessing_vectorize, - ) - else: - log_likelihood = None - - jax_fn = 
get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample) - result = _postprocess_samples( - jax_fn, - raw_mcmc_samples, - postprocessing_backend=postprocessing_backend, - postprocessing_vectorize=postprocessing_vectorize, - donate_samples=True, - ) - del raw_mcmc_samples - mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)} - attrs = { "sampling_time": (tic2 - tic1).total_seconds(), "tuning_steps": tune, From e364b1cb60c7221d0d12b55de9ca06bc467ee8fc Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Mon, 12 Aug 2024 10:01:10 -0700 Subject: [PATCH 02/17] fix print --- pymc/sampling/jax.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index 6deadc7f9..e44769bc1 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -329,10 +329,8 @@ def _multi_step(state, imm, ss): sample_fn = partial(_multi_step, imm=tuned_params["inverse_mass_matrix"], ss=tuned_params["step_size"]) - if progressbar and num_chunks > 1: + if progressbar: print("Sampling chunk %d of %d:" % (1, num_chunks)) - elif progressbar: - print("Sampling:") (last_state, seed), (samples, stats) = sample_fn((last_state, sample_seed)) del sample_seed if num_chunks == 1: @@ -414,10 +412,8 @@ def _sample_numpyro_nuts( key = jax.random.PRNGKey(random_seed) del random_seed key, _skey = jax.random.split(key) - if progressbar and num_chunks > 1: + if progressbar: print("Sampling chunk %d of %d:" % (1, num_chunks)) - elif progressbar: - print("Sampling:") pmap_numpyro.run(_skey, init_params=initial_points, extra_fields=extra_fields) del _skey raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True) From c771598a43add8fbce789ec51241536dff753e0c Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Mon, 12 Aug 2024 22:04:39 -0700 Subject: [PATCH 03/17] add test --- pymc/sampling/jax.py | 76 ++++++++++++++++------------ tests/sampling/test_mcmc_external.py | 15 +++++- 2 files changed, 58 insertions(+), 33 deletions(-) diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index e44769bc1..f68bb51cf 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -161,9 +161,11 @@ def _get_log_likelihood_fn(model: Model) -> Callable: """Compute log-likelihood for all observations""" elemwise_logp = model.logp(model.observed_RVs, sum=False) jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=elemwise_logp) + def log_likelihood_fn(samples): result = jax.vmap(jax_fn)(*samples) return {v.name: r for v, r in zip(model.observed_RVs, result)} + return log_likelihood_fn @@ -200,7 +202,7 @@ def _get_batched_jittered_initial_points( @partial(jax.jit, donate_argnums=0) def _set_tree(store, input, idx): def update_fn(save, inp): - starts = (idx, *([0] * (len(save.shape) - 1))) + starts = (save.shape[0], idx, *([0] * (len(save.shape) - 2))) return jax.lax.dynamic_update_slice(save, inp, starts) store = jax.tree.map(update_fn, store, input) @@ -216,12 +218,12 @@ def _do_chunked_sampling(last_state, tmpout, nchunk, nsteps, sample_fn, progress output = _set_tree( jax.tree.map(jax.vmap(partial(_gen_arr, nchunk=nchunk)), tmpout), jax.device_put(tmpout, jax.devices("cpu")[0]), - 0 + 0, ) for i in range(1, nchunk): if progressbar: - print("Sampling chunk %d of %d:" % (i+1, nchunk)) + logger.info("Sampling chunk %d of %d:" % (i + 1, nchunk)) last_state, tmpout = sample_fn(last_state) output = _set_tree( output, @@ -245,15 +247,17 @@ def _sample_blackjax_nuts( nuts_kwargs, num_chunks: int = 1, ) -> az.InferenceData: - import blackjax + from 
blackjax.adaptation.base import get_filter_adapt_info_fn # Adapted from numpyro if chain_method == "parallel": map_fn = jax.pmap elif chain_method == "vectorized": - map_fn = lambda x: jax.jit(jax.vmap(x)) + + def map_fn(x): + return jax.jit(jax.vmap(x)) else: raise ValueError( "Only supporting the following methods to draw chains:" ' "parallel" or "vectorized"' @@ -269,7 +273,6 @@ def _sample_blackjax_nuts( sample_seed = jax.random.split(s2, chains) del s1, s2 - algorithm_name = nuts_kwargs.pop("algorithm", "nuts") if algorithm_name == "nuts": algorithm = blackjax.nuts @@ -294,6 +297,7 @@ def _sample_blackjax_nuts( @map_fn def run_adaptation(seed, init_position): return adapt.run(seed, init_position, num_steps=tune) + (last_state, tuned_params), _ = run_adaptation(adapt_seed, initial_points) del adapt_seed @@ -327,16 +331,20 @@ def _multi_step(state, imm, ss): samples, log_likelihoods = postprocess_fn(raw_samples) return (last_state, key), ((samples, log_likelihoods), stats) - sample_fn = partial(_multi_step, imm=tuned_params["inverse_mass_matrix"], ss=tuned_params["step_size"]) + sample_fn = partial( + _multi_step, imm=tuned_params["inverse_mass_matrix"], ss=tuned_params["step_size"] + ) if progressbar: - print("Sampling chunk %d of %d:" % (1, num_chunks)) + logger.info("Sampling chunk %d of %d:" % (1, num_chunks)) (last_state, seed), (samples, stats) = sample_fn((last_state, sample_seed)) del sample_seed if num_chunks == 1: return samples[0], stats, samples[1], blackjax - - last_state, (all_samples, all_stats) = _do_chunked_sampling((last_state, seed), (samples, stats), num_chunks, nsteps, sample_fn, progressbar) + + last_state, (all_samples, all_stats) = _do_chunked_sampling( + (last_state, seed), (samples, stats), num_chunks, nsteps, sample_fn, progressbar + ) return all_samples[0], all_stats, all_samples[1], blackjax @@ -344,13 +352,13 @@ def _numpyro_stats_to_dict(posterior): """Extract sample_stats from NumPyro posterior.""" extra_fields = posterior.get_extra_fields(group_by_chain=True) data = { - "lp": extra_fields["potential_energy"], - "step_size": extra_fields["adapt_state.step_size"], - "n_steps": extra_fields["num_steps"], - "acceptance_rate": extra_fields["accept_prob"], - "tree_depth": jnp.log2(extra_fields["num_steps"]).astype(int) + 1, - "energy": extra_fields["energy"], - "diverging": extra_fields["diverging"], + "lp": extra_fields["potential_energy"], + "step_size": extra_fields["adapt_state.step_size"], + "n_steps": extra_fields["num_steps"], + "acceptance_rate": extra_fields["accept_prob"], + "tree_depth": jnp.log2(extra_fields["num_steps"]).astype(int) + 1, + "energy": extra_fields["energy"], + "diverging": extra_fields["diverging"], } return data @@ -369,8 +377,8 @@ def _sample_numpyro_nuts( nuts_kwargs: dict[str, Any], num_chunks: int = 1, ): - import numpyro + from numpyro.infer import MCMC, NUTS assert draws % num_chunks == 0 @@ -399,13 +407,13 @@ def _sample_numpyro_nuts( ) extra_fields = ( - "num_steps", - "potential_energy", - "energy", - "adapt_state.step_size", - "accept_prob", - "diverging", - ) + "num_steps", + "potential_energy", + "energy", + "adapt_state.step_size", + "accept_prob", + "diverging", + ) vmap_postprocess = jax.jit(jax.vmap(postprocess_fn)) @@ -413,7 +421,7 @@ def _sample_numpyro_nuts( del random_seed key, _skey = jax.random.split(key) if progressbar: - print("Sampling chunk %d of %d:" % (1, num_chunks)) + logger.info("Sampling chunk %d of %d:" % (1, num_chunks)) pmap_numpyro.run(_skey, init_params=initial_points, extra_fields=extra_fields) 
del _skey raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True) @@ -433,9 +441,15 @@ def sample_chunk(state): sample_stats = _numpyro_stats_to_dict(pmap_numpyro) mcmc_samples, likelihoods = vmap_postprocess(raw_mcmc_samples) return (pmap_numpyro.last_state, key), ((mcmc_samples, likelihoods), sample_stats) - - _, (all_samples, all_stats) = _do_chunked_sampling((pmap_numpyro.last_state, key), (samples, stats), num_chunks, nsteps, sample_chunk, progressbar) + _, (all_samples, all_stats) = _do_chunked_sampling( + (pmap_numpyro.last_state, key), + (samples, stats), + num_chunks, + nsteps, + sample_chunk, + progressbar, + ) return all_samples[0], all_stats, all_samples[1], numpyro @@ -455,9 +469,9 @@ def sample_jax_nuts( progressbar: bool = True, keep_untransformed: bool = False, chain_method: str = "parallel", - postprocessing_backend: Literal["cpu", "gpu"] | None = None, #Note unused - postprocessing_vectorize: Literal["vmap", "scan"] | None = None, #Note unused - postprocessing_chunks=None, #Note unused + postprocessing_backend: Literal["cpu", "gpu"] | None = None, # Note unused + postprocessing_vectorize: Literal["vmap", "scan"] | None = None, # Note unused + postprocessing_chunks=None, # Note unused idata_kwargs: dict | None = None, compute_convergence_checks: bool = True, nuts_sampler: Literal["numpyro", "blackjax"], diff --git a/tests/sampling/test_mcmc_external.py b/tests/sampling/test_mcmc_external.py index 4c594a2b6..c805033bf 100644 --- a/tests/sampling/test_mcmc_external.py +++ b/tests/sampling/test_mcmc_external.py @@ -19,8 +19,18 @@ from pymc import Data, Model, Normal, sample -@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"]) -def test_external_nuts_sampler(recwarn, nuts_sampler): +@pytest.mark.parametrize( + "nuts_sampler,nuts_kwargs", + [ + ("pymc", {}), + ("nutpie", {}), + ("blackjax", {}), + ("numpyro", {}), + ("blackjax", {"num_chunks": 10}), + ("numpyro", {"num_chunks": 10}), + ], +) +def test_external_nuts_sampler(recwarn, nuts_sampler, nuts_kwargs): if nuts_sampler != "pymc": pytest.importorskip(nuts_sampler) @@ -39,6 +49,7 @@ def test_external_nuts_sampler(recwarn, nuts_sampler): draws=500, progressbar=False, initvals={"x": 0.0}, + nuts_sampler_kwargs=nuts_kwargs, ) idata1 = sample(**kwargs) From 44dfc88c7dad3b71658f4bece8c5197f51d8e4b7 Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Tue, 13 Aug 2024 15:33:40 -0700 Subject: [PATCH 04/17] add test --- pymc/sampling/jax.py | 15 +++++++------- tests/sampling/test_jax.py | 4 ++-- tests/sampling/test_mcmc_external.py | 29 ++++++++++++++++++++++++++++ 3 files changed, 38 insertions(+), 10 deletions(-) diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index f68bb51cf..13e84e3bb 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -302,9 +302,11 @@ def run_adaptation(seed, init_position): del adapt_seed def _one_step(state, x, imm, ss): - _, rng_key = x + del x + state, rng_key = state + key, _skey = jax.random.split(rng_key) state, info = algorithm(logprob_fn, inverse_mass_matrix=imm, step_size=ss).step( - rng_key, state + _skey, state ) position = state.position stats = { @@ -315,18 +317,15 @@ def _one_step(state, x, imm, ss): "acceptance_rate": info.acceptance_rate, "lp": state.logdensity, } - return state, (position, stats) + return (state, key), (position, stats) @map_fn def _multi_step(state, imm, ss): start_state, key = state - key, _skey = jax.random.split(key) - _skeys = jax.random.split(_skey, nsteps) - scan_fn = blackjax.progress_bar.gen_scan_fn(nsteps, 
progressbar) - last_state, (raw_samples, stats) = scan_fn( - partial(_one_step, imm=imm, ss=ss), start_state, (jnp.arange(nsteps), _skeys) + (last_state, key), (raw_samples, stats) = scan_fn( + partial(_one_step, imm=imm, ss=ss), (start_state, key), jnp.arange(nsteps) ) samples, log_likelihoods = postprocess_fn(raw_samples) return (last_state, key), ((samples, log_likelihoods), stats) diff --git a/tests/sampling/test_jax.py b/tests/sampling/test_jax.py index 5f32e1075..1af7308a8 100644 --- a/tests/sampling/test_jax.py +++ b/tests/sampling/test_jax.py @@ -38,7 +38,7 @@ from pymc.model.transform.optimization import freeze_dims_and_data from pymc.sampling.jax import ( _get_batched_jittered_initial_points, - _get_log_likelihood, + _get_log_likelihood_fn, _replace_shared_variables, get_jaxified_graph, get_jaxified_logp, @@ -229,7 +229,7 @@ def test_get_log_likelihood(): b_true = trace.log_likelihood.b.values a = np.array(trace.posterior.a) sigma_log_ = np.log(np.array(trace.posterior.sigma)) - b_jax = _get_log_likelihood(model, [a, sigma_log_])["b"] + b_jax = jax.vmap(_get_log_likelihood_fn(model))([a, sigma_log_])["b"] assert np.allclose(b_jax.reshape(-1), b_true.reshape(-1)) diff --git a/tests/sampling/test_mcmc_external.py b/tests/sampling/test_mcmc_external.py index c805033bf..89fa8d30e 100644 --- a/tests/sampling/test_mcmc_external.py +++ b/tests/sampling/test_mcmc_external.py @@ -85,6 +85,35 @@ def test_external_nuts_sampler(recwarn, nuts_sampler, nuts_kwargs): assert idata_reference.posterior.attrs.keys() == idata1.posterior.attrs.keys() +def test_blackjax_chunking(): + # blackjax should have same sampling whether chunked or not + nuts_sampler = "blackjax" + pytest.importorskip(nuts_sampler) + + with Model(): + x = Normal("x", 100, 5) + y = Data("y", [1, 2, 3, 4]) + + Normal("L", mu=x, sigma=0.1, observed=y) + + base_kwargs = dict( + nuts_sampler=nuts_sampler, + random_seed=123, + chains=2, + tune=500, + draws=500, + progressbar=False, + initvals={"x": 0.0}, + ) + chunk_kwargs = {**base_kwargs, **{"nuts_sampler_kwargs": {"num_chunks": 10}}} + + idata1 = sample(**base_kwargs) + idata2 = sample(**chunk_kwargs) + + np.testing.assert_array_equal(idata1.posterior.x, idata2.posterior.x) + assert idata1.posterior.attrs.keys() == idata2.posterior.attrs.keys() + + def test_step_args(): with Model() as model: a = Normal("a") From a1b6a9b13a629a6585bc13c85456fbd46520570b Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Tue, 13 Aug 2024 16:03:41 -0700 Subject: [PATCH 05/17] add numpyro test --- pymc/sampling/jax.py | 10 ++++------ tests/sampling/test_mcmc_external.py | 19 ++++--------------- 2 files changed, 8 insertions(+), 21 deletions(-) diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index 13e84e3bb..33e8b4c5c 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -431,18 +431,16 @@ def _sample_numpyro_nuts( return samples[0], stats, samples[1], numpyro def sample_chunk(state): - pmap_numpyro.post_warmup_state, key = state - key, _skey = jax.random.split(key) - pmap_numpyro.run(_skey, extra_fields=extra_fields) - del _skey + pmap_numpyro.post_warmup_state = state + pmap_numpyro.run(pmap_numpyro.post_warmup_state.rng_key, extra_fields=extra_fields) raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True) sample_stats = _numpyro_stats_to_dict(pmap_numpyro) mcmc_samples, likelihoods = vmap_postprocess(raw_mcmc_samples) - return (pmap_numpyro.last_state, key), ((mcmc_samples, likelihoods), sample_stats) + return pmap_numpyro.last_state, ((mcmc_samples, likelihoods), 
sample_stats) _, (all_samples, all_stats) = _do_chunked_sampling( - (pmap_numpyro.last_state, key), + pmap_numpyro.last_state, (samples, stats), num_chunks, nsteps, diff --git a/tests/sampling/test_mcmc_external.py b/tests/sampling/test_mcmc_external.py index 89fa8d30e..d2ffc1b31 100644 --- a/tests/sampling/test_mcmc_external.py +++ b/tests/sampling/test_mcmc_external.py @@ -19,18 +19,8 @@ from pymc import Data, Model, Normal, sample -@pytest.mark.parametrize( - "nuts_sampler,nuts_kwargs", - [ - ("pymc", {}), - ("nutpie", {}), - ("blackjax", {}), - ("numpyro", {}), - ("blackjax", {"num_chunks": 10}), - ("numpyro", {"num_chunks": 10}), - ], -) -def test_external_nuts_sampler(recwarn, nuts_sampler, nuts_kwargs): +@pytest.mark.parametrize("nuts_sampler", ["pymc", "nutpie", "blackjax", "numpyro"]) +def test_external_nuts_sampler(recwarn, nuts_sampler): if nuts_sampler != "pymc": pytest.importorskip(nuts_sampler) @@ -49,7 +39,6 @@ def test_external_nuts_sampler(recwarn, nuts_sampler, nuts_kwargs): draws=500, progressbar=False, initvals={"x": 0.0}, - nuts_sampler_kwargs=nuts_kwargs, ) idata1 = sample(**kwargs) @@ -85,9 +74,9 @@ def test_external_nuts_sampler(recwarn, nuts_sampler, nuts_kwargs): assert idata_reference.posterior.attrs.keys() == idata1.posterior.attrs.keys() -def test_blackjax_chunking(): +@pytest.mark.parametrize("nuts_sampler", ["blackjax", "numpyro"]) +def test_external_nuts_chunking(nuts_sampler): # blackjax should have same sampling whether chunked or not - nuts_sampler = "blackjax" pytest.importorskip(nuts_sampler) with Model(): From ff32bac636b30788b027aad46b19993e7e87066f Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Tue, 13 Aug 2024 20:56:03 -0700 Subject: [PATCH 06/17] fix --- pymc/sampling/jax.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index 33e8b4c5c..aaf6c653f 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -301,13 +301,11 @@ def run_adaptation(seed, init_position): (last_state, tuned_params), _ = run_adaptation(adapt_seed, initial_points) del adapt_seed - def _one_step(state, x, imm, ss): + def _one_step(state, x, kernel): del x state, rng_key = state key, _skey = jax.random.split(rng_key) - state, info = algorithm(logprob_fn, inverse_mass_matrix=imm, step_size=ss).step( - _skey, state - ) + state, info = kernel(_skey, state) position = state.position stats = { "diverging": info.is_divergent, @@ -324,8 +322,10 @@ def _multi_step(state, imm, ss): start_state, key = state scan_fn = blackjax.progress_bar.gen_scan_fn(nsteps, progressbar) + kernel = algorithm(logprob_fn, inverse_mass_matrix=imm, step_size=ss).step + (last_state, key), (raw_samples, stats) = scan_fn( - partial(_one_step, imm=imm, ss=ss), (start_state, key), jnp.arange(nsteps) + partial(_one_step, kernel=kernel), (start_state, key), jnp.arange(nsteps) ) samples, log_likelihoods = postprocess_fn(raw_samples) return (last_state, key), ((samples, log_likelihoods), stats) From 190ad0fb5520867f48e6d405cd3ccaa5786249e6 Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Tue, 13 Aug 2024 21:12:56 -0700 Subject: [PATCH 07/17] remove overarching jit --- pymc/sampling/jax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index aaf6c653f..049124f6c 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -257,7 +257,7 @@ def _sample_blackjax_nuts( elif chain_method == "vectorized": def map_fn(x): - return jax.jit(jax.vmap(x)) + return jax.vmap(x) 
# jitting here hurts memory performance
     else:
         raise ValueError(
             "Only supporting the following methods to draw chains:" ' "parallel" or "vectorized"'

From 0ed006639e817d7ef360a828d572f88858d2d80a Mon Sep 17 00:00:00 2001
From: andrewdipper
Date: Wed, 14 Aug 2024 17:25:59 -0700
Subject: [PATCH 08/17] allocate chunk output buffers in the samplers

---
 pymc/sampling/jax.py | 33 ++++++++++++++++++++-------------
 1 file changed, 20 insertions(+), 13 deletions(-)

diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py
index 049124f6c..3026de108 100644
--- a/pymc/sampling/jax.py
+++ b/pymc/sampling/jax.py
@@ -214,13 +214,7 @@ def _gen_arr(inp, nchunk):
     return jnp.zeros(shape, dtype=inp.dtype, device=jax.devices("cpu")[0])
 
 
-def _do_chunked_sampling(last_state, tmpout, nchunk, nsteps, sample_fn, progressbar):
-    output = _set_tree(
-        jax.tree.map(jax.vmap(partial(_gen_arr, nchunk=nchunk)), tmpout),
-        jax.device_put(tmpout, jax.devices("cpu")[0]),
-        0,
-    )
-
+def _do_chunked_sampling(last_state, output, nchunk, nsteps, sample_fn, progressbar):
     for i in range(1, nchunk):
         if progressbar:
             logger.info("Sampling chunk %d of %d:" % (i + 1, nchunk))
@@ -230,6 +224,7 @@ def _do_chunked_sampling(last_state, tmpout, nchunk, nsteps, sample_fn, progress
             jax.device_put(tmpout, jax.devices("cpu")[0]),
             nsteps * i,
         )
+        del tmpout
     return last_state, output
 
 
@@ -257,7 +252,7 @@ def _sample_blackjax_nuts(
     elif chain_method == "vectorized":
 
         def map_fn(x):
-            return jax.vmap(x)  # jitting here hurts memory performance
+            return jax.jit(jax.vmap(x))
     else:
         raise ValueError(
             "Only supporting the following methods to draw chains:" ' "parallel" or "vectorized"'
@@ -269,7 +266,6 @@ def map_fn(x):
 
     s1, s2 = jax.random.split(jax.random.PRNGKey(random_seed))
     adapt_seed = jax.random.split(s1, chains)
     sample_seed = jax.random.split(s2, chains)
-    del s1, s2
 
     algorithm_name = nuts_kwargs.pop("algorithm", "nuts")
@@ -333,16 +327,22 @@ def _multi_step(state, imm, ss):
     sample_fn = partial(
         _multi_step, imm=tuned_params["inverse_mass_matrix"], ss=tuned_params["step_size"]
     )
+    sample_fn = jax.jit(sample_fn, donate_argnums=0)
     if progressbar:
         logger.info("Sampling chunk %d of %d:" % (1, num_chunks))
     (last_state, seed), (samples, stats) = sample_fn((last_state, sample_seed))
-    del sample_seed
     if num_chunks == 1:
         return samples[0], stats, samples[1], blackjax
 
+    output = _set_tree(
+        jax.tree.map(jax.vmap(partial(_gen_arr, nchunk=num_chunks)), (samples, stats)),
+        jax.device_put((samples, stats), jax.devices("cpu")[0]),
+        0,
+    )
+    del samples, stats
     last_state, (all_samples, all_stats) = _do_chunked_sampling(
-        (last_state, seed), (samples, stats), num_chunks, nsteps, sample_fn, progressbar
+        (last_state, seed), output, num_chunks, nsteps, sample_fn, progressbar
     )
     return all_samples[0], all_stats, all_samples[1], blackjax
 
@@ -439,9 +439,16 @@ def sample_chunk(state):
         mcmc_samples, likelihoods = vmap_postprocess(raw_mcmc_samples)
         return pmap_numpyro.last_state, ((mcmc_samples, likelihoods), sample_stats)
 
+    output = _set_tree(
+        jax.tree.map(jax.vmap(partial(_gen_arr, nchunk=num_chunks)), (samples, stats)),
+        jax.device_put((samples, stats), jax.devices("cpu")[0]),
+        0,
+    )
+    del samples, stats
+
     _, (all_samples, all_stats) = _do_chunked_sampling(
         pmap_numpyro.last_state,
-        (samples, stats),
+        output,
         num_chunks,
         nsteps,
         sample_chunk,
         progressbar,
     )

From feb53b78c3149ca3957230965016425314884c72 Mon Sep 17 00:00:00 2001
From: andrewdipper
Date: Thu, 15 Aug 2024 08:06:35 -0700
Subject: [PATCH 09/17] fix jit

---
 pymc/sampling/jax.py | 3 +--
 1 file changed, 1 insertion(+), 2
deletions(-) diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index 3026de108..3aa4f0f18 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -312,6 +312,7 @@ def _one_step(state, x, kernel): return (state, key), (position, stats) @map_fn + @partial(jax.jit, donate_argnums=0) def _multi_step(state, imm, ss): start_state, key = state scan_fn = blackjax.progress_bar.gen_scan_fn(nsteps, progressbar) @@ -327,7 +328,6 @@ def _multi_step(state, imm, ss): sample_fn = partial( _multi_step, imm=tuned_params["inverse_mass_matrix"], ss=tuned_params["step_size"] ) - sample_fn = jax.jit(sample_fn, donate_argnums=0) if progressbar: logger.info("Sampling chunk %d of %d:" % (1, num_chunks)) (last_state, seed), (samples, stats) = sample_fn((last_state, sample_seed)) @@ -377,7 +377,6 @@ def _sample_numpyro_nuts( num_chunks: int = 1, ): import numpyro - from numpyro.infer import MCMC, NUTS assert draws % num_chunks == 0 From 0afc9734c957de14336882bf96362696d3b3827a Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Thu, 15 Aug 2024 08:09:53 -0700 Subject: [PATCH 10/17] fix docstring --- pymc/sampling/jax.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index 3aa4f0f18..8fd4ad0b8 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -539,6 +539,9 @@ def sample_jax_nuts( nuts_sampler : Literal["numpyro", "blackjax"] Nuts sampler library to use - do not change - use sample_numpyro_nuts or sample_blackjax_nuts as appropriate + num_chunks : int + Splits sampling into multiple chunks and collects them on the cpu. Reduces gpu memory + usage when sampling. There is no benefit when sampling on the cpu. Returns ------- From 270d0b3fa770fb69d79c7c605c6ba9552374fd21 Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Thu, 15 Aug 2024 11:00:00 -0700 Subject: [PATCH 11/17] update docstrings --- pymc/sampling/jax.py | 9 +++++---- tests/sampling/test_mcmc_external.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index 8fd4ad0b8..a5b7b0fba 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -158,7 +158,7 @@ def logp_fn_wrap(x): def _get_log_likelihood_fn(model: Model) -> Callable: - """Compute log-likelihood for all observations""" + """Generate function to compute log-likelihood for all observations""" elemwise_logp = model.logp(model.observed_RVs, sum=False) jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=elemwise_logp) @@ -201,6 +201,7 @@ def _get_batched_jittered_initial_points( @partial(jax.jit, donate_argnums=0) def _set_tree(store, input, idx): + """Update pytree of outputs - used for saving results of chunked sampling""" def update_fn(save, inp): starts = (save.shape[0], idx, *([0] * (len(save.shape) - 2))) return jax.lax.dynamic_update_slice(save, inp, starts) @@ -210,11 +211,13 @@ def update_fn(save, inp): def _gen_arr(inp, nchunk): + """Generate output array on cpu for chunked sampling""" shape = (inp.shape[0] * nchunk, *inp.shape[1:]) return jnp.zeros(shape, dtype=inp.dtype, device=jax.devices("cpu")[0]) def _do_chunked_sampling(last_state, output, nchunk, nsteps, sample_fn, progressbar): + """Run chunked sampling saving to output on the cpu""" for i in range(1, nchunk): if progressbar: logger.info("Sampling chunk %d of %d:" % (i + 1, nchunk)) @@ -293,7 +296,6 @@ def run_adaptation(seed, init_position): return adapt.run(seed, init_position, num_steps=tune) (last_state, tuned_params), _ = run_adaptation(adapt_seed, initial_points) - del adapt_seed 
def _one_step(state, x, kernel): del x @@ -377,6 +379,7 @@ def _sample_numpyro_nuts( num_chunks: int = 1, ): import numpyro + from numpyro.infer import MCMC, NUTS assert draws % num_chunks == 0 @@ -416,12 +419,10 @@ def _sample_numpyro_nuts( vmap_postprocess = jax.jit(jax.vmap(postprocess_fn)) key = jax.random.PRNGKey(random_seed) - del random_seed key, _skey = jax.random.split(key) if progressbar: logger.info("Sampling chunk %d of %d:" % (1, num_chunks)) pmap_numpyro.run(_skey, init_params=initial_points, extra_fields=extra_fields) - del _skey raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True) stats = _numpyro_stats_to_dict(pmap_numpyro) samples = vmap_postprocess(raw_mcmc_samples) diff --git a/tests/sampling/test_mcmc_external.py b/tests/sampling/test_mcmc_external.py index d2ffc1b31..c139c0da9 100644 --- a/tests/sampling/test_mcmc_external.py +++ b/tests/sampling/test_mcmc_external.py @@ -76,7 +76,7 @@ def test_external_nuts_sampler(recwarn, nuts_sampler): @pytest.mark.parametrize("nuts_sampler", ["blackjax", "numpyro"]) def test_external_nuts_chunking(nuts_sampler): - # blackjax should have same sampling whether chunked or not + # chunked sampling should give exact same results as non-chunked pytest.importorskip(nuts_sampler) with Model(): From 9e2362e839181d172b20b480d389d95bb9a0eb2f Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Thu, 15 Aug 2024 11:00:36 -0700 Subject: [PATCH 12/17] fix lint --- pymc/sampling/jax.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index a5b7b0fba..0507440c5 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -202,6 +202,7 @@ def _get_batched_jittered_initial_points( @partial(jax.jit, donate_argnums=0) def _set_tree(store, input, idx): """Update pytree of outputs - used for saving results of chunked sampling""" + def update_fn(save, inp): starts = (save.shape[0], idx, *([0] * (len(save.shape) - 2))) return jax.lax.dynamic_update_slice(save, inp, starts) From 57e52e38798f5c2b30df270ba92810503bafd582 Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Fri, 16 Aug 2024 12:42:49 -0700 Subject: [PATCH 13/17] enhance test --- tests/sampling/test_mcmc_external.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/sampling/test_mcmc_external.py b/tests/sampling/test_mcmc_external.py index c139c0da9..af4f04e2f 100644 --- a/tests/sampling/test_mcmc_external.py +++ b/tests/sampling/test_mcmc_external.py @@ -93,6 +93,7 @@ def test_external_nuts_chunking(nuts_sampler): draws=500, progressbar=False, initvals={"x": 0.0}, + idata_kwargs={"log_likelihood": True}, ) chunk_kwargs = {**base_kwargs, **{"nuts_sampler_kwargs": {"num_chunks": 10}}} @@ -100,6 +101,7 @@ def test_external_nuts_chunking(nuts_sampler): idata2 = sample(**chunk_kwargs) np.testing.assert_array_equal(idata1.posterior.x, idata2.posterior.x) + np.testing.assert_array_equal(idata1.log_likelihood.L, idata2.log_likelihood.L) assert idata1.posterior.attrs.keys() == idata2.posterior.attrs.keys() From 641e14210a890e1af625b831a72c22b70abe7ac6 Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Sat, 17 Aug 2024 18:00:26 -0700 Subject: [PATCH 14/17] fix _get_log_likelihood / numpyro rng --- pymc/sampling/jax.py | 15 +++++---------- tests/sampling/test_jax.py | 4 ++-- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index 0507440c5..b73ed0f51 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -157,16 +157,12 @@ def logp_fn_wrap(x): return logp_fn_wrap -def 
_get_log_likelihood_fn(model: Model) -> Callable:
+def _get_log_likelihood(model: Model, samples) -> dict:
     """Generate function to compute log-likelihood for all observations"""
     elemwise_logp = model.logp(model.observed_RVs, sum=False)
     jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=elemwise_logp)
-
-    def log_likelihood_fn(samples):
-        result = jax.vmap(jax_fn)(*samples)
-        return {v.name: r for v, r in zip(model.observed_RVs, result)}
-
-    return log_likelihood_fn
+    result = jax.vmap(jax_fn)(*samples)
+    return {v.name: r for v, r in zip(model.observed_RVs, result)}
 
 
 def _get_batched_jittered_initial_points(
@@ -420,10 +416,9 @@ def _sample_numpyro_nuts(
     vmap_postprocess = jax.jit(jax.vmap(postprocess_fn))
 
     key = jax.random.PRNGKey(random_seed)
-    key, _skey = jax.random.split(key)
     if progressbar:
         logger.info("Sampling chunk %d of %d:" % (1, num_chunks))
-    pmap_numpyro.run(_skey, init_params=initial_points, extra_fields=extra_fields)
+    pmap_numpyro.run(key, init_params=initial_points, extra_fields=extra_fields)
     raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
     stats = _numpyro_stats_to_dict(pmap_numpyro)
     samples = vmap_postprocess(raw_mcmc_samples)
@@ -594,7 +589,7 @@ def sample_jax_nuts(
         get_default_varnames(filtered_var_names, include_transformed=keep_untransformed)
     )
 
-    log_likelihood_fn = _get_log_likelihood_fn(model)
+    log_likelihood_fn = partial(_get_log_likelihood, model)
     transform_jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
 
     def postprocess_fn(samples):
diff --git a/tests/sampling/test_jax.py b/tests/sampling/test_jax.py
index 1af7308a8..30000f859 100644
--- a/tests/sampling/test_jax.py
+++ b/tests/sampling/test_jax.py
@@ -38,7 +38,7 @@ from pymc.model.transform.optimization import freeze_dims_and_data
 from pymc.sampling.jax import (
     _get_batched_jittered_initial_points,
-    _get_log_likelihood_fn,
+    _get_log_likelihood,
     _replace_shared_variables,
     get_jaxified_graph,
     get_jaxified_logp,
@@ -229,7 +229,7 @@ def test_get_log_likelihood():
     b_true = trace.log_likelihood.b.values
     a = np.array(trace.posterior.a)
     sigma_log_ = np.log(np.array(trace.posterior.sigma))
-    b_jax = jax.vmap(_get_log_likelihood_fn(model))([a, sigma_log_])["b"]
+    b_jax = jax.vmap(_get_log_likelihood, in_axes=(None, 0))(model, [a, sigma_log_])["b"]
 
     assert np.allclose(b_jax.reshape(-1), b_true.reshape(-1))
 

From 9108c60d4ca945a963fa83ace095ea51e4af8541 Mon Sep 17 00:00:00 2001
From: andrewdipper
Date: Sat, 17 Aug 2024 20:56:32 -0700
Subject: [PATCH 15/17] add postprocessing on different device

---
 pymc/sampling/jax.py                 | 21 ++++++++++++--
 tests/sampling/test_mcmc_external.py | 41 ++++++++++++++++++++++++++++
 2 files changed, 60 insertions(+), 2 deletions(-)

diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py
index b73ed0f51..516b631a5 100644
--- a/pymc/sampling/jax.py
+++ b/pymc/sampling/jax.py
@@ -561,7 +561,8 @@ def sample_jax_nuts(
 
         warnings.warn(
             "postprocessing_backend={'cpu', 'gpu'} will be removed in a future release, "
-            "postprocessing is done on sampling device.",
+            "postprocessing will be done on sampling device in the future. 
If device memory " + "consumption is an issue please use num_chunks to reduce consumption.", DeprecationWarning, ) @@ -617,6 +618,14 @@ def postprocess_fn(samples): else: raise ValueError(f"{nuts_sampler=} not recognized") + current_backend = jax.default_backend() + if postprocessing_backend is not None and current_backend != postprocessing_backend: + + def process_fn(x): + return x, None + else: + process_fn = postprocess_fn + tic1 = datetime.now() mcmc_samples, sample_stats, log_likelihood, library = sampler_fn( model=model, @@ -628,10 +637,18 @@ def postprocess_fn(samples): progressbar=progressbar, random_seed=random_seed, initial_points=initial_points, - postprocess_fn=postprocess_fn, + postprocess_fn=process_fn, nuts_kwargs=nuts_kwargs, num_chunks=num_chunks, ) + + if postprocessing_backend is not None and current_backend != postprocessing_backend: + mcmc_samples = jax.device_put(mcmc_samples, jax.devices(postprocessing_backend)[0]) + sample_stats = jax.device_put(sample_stats, jax.devices(postprocessing_backend)[0]) + mcmc_samples, log_likelihood = jax.jit(jax.vmap(postprocess_fn), donate_argnums=0)( + mcmc_samples + ) + tic2 = datetime.now() attrs = { diff --git a/tests/sampling/test_mcmc_external.py b/tests/sampling/test_mcmc_external.py index af4f04e2f..654e3c70c 100644 --- a/tests/sampling/test_mcmc_external.py +++ b/tests/sampling/test_mcmc_external.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import jax import numpy as np import numpy.testing as npt import pytest @@ -116,3 +117,43 @@ def test_step_args(): ) npt.assert_almost_equal(idata.sample_stats.acceptance_rate.mean(), 0.5, decimal=1) + + +@pytest.mark.skipif(jax.default_backend() == "cpu", reason="need default backend that is not cpu") +@pytest.mark.parametrize("nuts_sampler", ["blackjax", "numpyro"]) +def test_postprocessing_backend(nuts_sampler): + pytest.importorskip(nuts_sampler) + default_backend = jax.default_backend() + + with Model(): + x = Normal("x", 100, 5) + y = Data("y", [1, 2, 3, 4]) + + Normal("L", mu=x, sigma=0.1, observed=y) + + base_kwargs = dict( + nuts_sampler=nuts_sampler, + random_seed=123, + chains=4, + tune=200, + draws=200, + progressbar=False, + initvals={"x": 0.0}, + idata_kwargs={"log_likelihood": True}, + ) + + idata1 = sample( + **base_kwargs, + nuts_sampler_kwargs={ + "postprocessing_backend": default_backend, + "chain_method": "vectorized", + }, + ) + idata2 = sample( + **base_kwargs, + nuts_sampler_kwargs={"postprocessing_backend": "cpu", "chain_method": "vectorized"}, + ) + + np.testing.assert_array_equal(idata1.posterior.x, idata2.posterior.x) + np.testing.assert_array_equal(idata1.log_likelihood.L, idata2.log_likelihood.L) + assert idata1.posterior.attrs.keys() == idata2.posterior.attrs.keys() From b34b4ccdb1fd641617604ef6e8ebb03ce32854a4 Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Sun, 6 Oct 2024 22:09:36 -0600 Subject: [PATCH 16/17] switch to blackjax run_inference_algorithm --- pymc/sampling/jax.py | 79 +++++++++++++++------------- tests/sampling/test_mcmc_external.py | 57 +------------------- 2 files changed, 45 insertions(+), 91 deletions(-) diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index 516b631a5..c60fcac68 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -278,27 +278,22 @@ def map_fn(x): assert draws % num_chunks == 0 nsteps = draws // num_chunks - # Run adaptation - adapt = blackjax.window_adaptation( - algorithm=algorithm, - logdensity_fn=logprob_fn, - 
target_acceptance_rate=target_accept, - adaptation_info_fn=get_filter_adapt_info_fn(), - progress_bar=progressbar, - **nuts_kwargs, - ) - + # Run adaptation for sampling parameters @map_fn def run_adaptation(seed, init_position): - return adapt.run(seed, init_position, num_steps=tune) - - (last_state, tuned_params), _ = run_adaptation(adapt_seed, initial_points) - - def _one_step(state, x, kernel): - del x - state, rng_key = state - key, _skey = jax.random.split(rng_key) - state, info = kernel(_skey, state) + return blackjax.window_adaptation( + algorithm=algorithm, + logdensity_fn=logprob_fn, + target_acceptance_rate=target_accept, + adaptation_info_fn=get_filter_adapt_info_fn(), + progress_bar=progressbar, + **nuts_kwargs, + ).run(seed, init_position, num_steps=tune) + + (adapt_state, tuned_params), _ = run_adaptation(adapt_seed, initial_points) + + # Filters output from each sampling step + def _transform_fn(state, info): position = state.position stats = { "diverging": info.is_divergent, @@ -308,42 +303,54 @@ def _one_step(state, x, kernel): "acceptance_rate": info.acceptance_rate, "lp": state.logdensity, } - return (state, key), (position, stats) + return position, stats + # Performs sampling for each chunk + # random keys are carried with state @map_fn @partial(jax.jit, donate_argnums=0) def _multi_step(state, imm, ss): - start_state, key = state - scan_fn = blackjax.progress_bar.gen_scan_fn(nsteps, progressbar) - - kernel = algorithm(logprob_fn, inverse_mass_matrix=imm, step_size=ss).step - - (last_state, key), (raw_samples, stats) = scan_fn( - partial(_one_step, kernel=kernel), (start_state, key), jnp.arange(nsteps) + state, key = state + key, _skey = jax.random.split(key) + last_state, (raw_samples, stats) = blackjax.util.run_inference_algorithm( + _skey, + algorithm(logprob_fn, inverse_mass_matrix=imm, step_size=ss), + num_steps=nsteps, + initial_state=state, + progress_bar=progressbar, + transform=_transform_fn, ) samples, log_likelihoods = postprocess_fn(raw_samples) return (last_state, key), ((samples, log_likelihoods), stats) - sample_fn = partial( + chunk_sample_fn = partial( _multi_step, imm=tuned_params["inverse_mass_matrix"], ss=tuned_params["step_size"] ) + if progressbar: logger.info("Sampling chunk %d of %d:" % (1, num_chunks)) - (last_state, seed), (samples, stats) = sample_fn((last_state, sample_seed)) + + # Sample first chunk + last_state, sample_data = chunk_sample_fn((adapt_state, sample_seed)) + + # If single chunk sampling return results on device if num_chunks == 1: - return samples[0], stats, samples[1], blackjax + ((samples, log_likelihoods), stats) = sample_data + return samples, stats, log_likelihoods, blackjax + # Provision space for all samples on the cpu + save first chunk output = _set_tree( - jax.tree.map(jax.vmap(partial(_gen_arr, nchunk=num_chunks)), (samples, stats)), - jax.device_put((samples, stats), jax.devices("cpu")[0]), + jax.tree.map(jax.vmap(partial(_gen_arr, nchunk=num_chunks)), sample_data), + jax.device_put(sample_data, jax.devices("cpu")[0]), 0, ) - del samples, stats + del sample_data - last_state, (all_samples, all_stats) = _do_chunked_sampling( - (last_state, seed), output, num_chunks, nsteps, sample_fn, progressbar + # Sample remaining chunks + _, ((samples, log_likelihoods), stats) = _do_chunked_sampling( + last_state, output, num_chunks, nsteps, chunk_sample_fn, progressbar ) - return all_samples[0], all_stats, all_samples[1], blackjax + return samples, stats, log_likelihoods, blackjax def _numpyro_stats_to_dict(posterior): diff 
--git a/tests/sampling/test_mcmc_external.py b/tests/sampling/test_mcmc_external.py index 654e3c70c..d1f4544f4 100644 --- a/tests/sampling/test_mcmc_external.py +++ b/tests/sampling/test_mcmc_external.py @@ -75,9 +75,9 @@ def test_external_nuts_sampler(recwarn, nuts_sampler): assert idata_reference.posterior.attrs.keys() == idata1.posterior.attrs.keys() -@pytest.mark.parametrize("nuts_sampler", ["blackjax", "numpyro"]) -def test_external_nuts_chunking(nuts_sampler): +def test_numpyro_external_nuts_chunking(): # chunked sampling should give exact same results as non-chunked + nuts_sampler = "numpyro" pytest.importorskip(nuts_sampler) with Model(): @@ -104,56 +104,3 @@ def test_external_nuts_chunking(nuts_sampler): np.testing.assert_array_equal(idata1.posterior.x, idata2.posterior.x) np.testing.assert_array_equal(idata1.log_likelihood.L, idata2.log_likelihood.L) assert idata1.posterior.attrs.keys() == idata2.posterior.attrs.keys() - - -def test_step_args(): - with Model() as model: - a = Normal("a") - idata = sample( - nuts_sampler="numpyro", - target_accept=0.5, - nuts={"max_treedepth": 10}, - random_seed=1410, - ) - - npt.assert_almost_equal(idata.sample_stats.acceptance_rate.mean(), 0.5, decimal=1) - - -@pytest.mark.skipif(jax.default_backend() == "cpu", reason="need default backend that is not cpu") -@pytest.mark.parametrize("nuts_sampler", ["blackjax", "numpyro"]) -def test_postprocessing_backend(nuts_sampler): - pytest.importorskip(nuts_sampler) - default_backend = jax.default_backend() - - with Model(): - x = Normal("x", 100, 5) - y = Data("y", [1, 2, 3, 4]) - - Normal("L", mu=x, sigma=0.1, observed=y) - - base_kwargs = dict( - nuts_sampler=nuts_sampler, - random_seed=123, - chains=4, - tune=200, - draws=200, - progressbar=False, - initvals={"x": 0.0}, - idata_kwargs={"log_likelihood": True}, - ) - - idata1 = sample( - **base_kwargs, - nuts_sampler_kwargs={ - "postprocessing_backend": default_backend, - "chain_method": "vectorized", - }, - ) - idata2 = sample( - **base_kwargs, - nuts_sampler_kwargs={"postprocessing_backend": "cpu", "chain_method": "vectorized"}, - ) - - np.testing.assert_array_equal(idata1.posterior.x, idata2.posterior.x) - np.testing.assert_array_equal(idata1.log_likelihood.L, idata2.log_likelihood.L) - assert idata1.posterior.attrs.keys() == idata2.posterior.attrs.keys() From 4e6b1fab15f121119e0d38beb427d16bda8b372a Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Sun, 6 Oct 2024 22:43:48 -0600 Subject: [PATCH 17/17] fix lint --- tests/sampling/test_mcmc_external.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/sampling/test_mcmc_external.py b/tests/sampling/test_mcmc_external.py index d1f4544f4..364d8e0f8 100644 --- a/tests/sampling/test_mcmc_external.py +++ b/tests/sampling/test_mcmc_external.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import jax import numpy as np -import numpy.testing as npt import pytest from pymc import Data, Model, Normal, sample
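
Usage sketch: the chunked samplers added by this series are reached through
nuts_sampler_kwargs, exactly as the tests above exercise them. A minimal,
illustrative example (the model mirrors the test model; draws must be
divisible by num_chunks, which the samplers assert, and each finished chunk
is copied to host (cpu) memory before the next one is drawn):

    import pymc as pm

    with pm.Model():
        x = pm.Normal("x", 100, 5)
        pm.Normal("L", mu=x, sigma=0.1, observed=[1, 2, 3, 4])

        idata = pm.sample(
            tune=500,
            draws=500,
            chains=2,
            nuts_sampler="numpyro",  # "blackjax" also supports num_chunks
            # 500 draws split into 10 chunks of 50: only one chunk of draws
            # (and its postprocessed variables) is resident on the sampling
            # device at a time
            nuts_sampler_kwargs={"num_chunks": 10},
        )

Per the deprecation warnings added in the series, num_chunks is the
suggested replacement for postprocessing_backend / postprocessing_vectorize
when device memory consumption is the constraint.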