From bfd3bdfa536b5745f7959334f0f127e60a857f8a Mon Sep 17 00:00:00 2001 From: colcarroll Date: Wed, 30 Apr 2025 15:09:08 -0400 Subject: [PATCH] Remove flowmc for now --- bayeux/_src/mcmc/flowmc.py | 252 ------------------------------------- bayeux/mcmc/__init__.py | 12 -- bayeux/tests/mcmc_test.py | 19 --- docs/inspecting.md | 12 -- pyproject.toml | 1 - 5 files changed, 296 deletions(-) delete mode 100644 bayeux/_src/mcmc/flowmc.py diff --git a/bayeux/_src/mcmc/flowmc.py b/bayeux/_src/mcmc/flowmc.py deleted file mode 100644 index 45e68fa..0000000 --- a/bayeux/_src/mcmc/flowmc.py +++ /dev/null @@ -1,252 +0,0 @@ -# Copyright 2024 The bayeux Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copyright 2024 The bayeux Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""flowMC specific code.""" -import arviz as az -from bayeux._src import shared -from flowMC import Sampler -from flowMC.resource.nf_model import realNVP -from flowMC.resource.nf_model import rqSpline -from flowMC.resource.local_kernel import HMC -from flowMC.resource.local_kernel import MALA -import jax -import jax.numpy as jnp - - -_NF_MODELS = { - "real_nvp": realNVP.RealNVP, - "masked_coupling_rq_spline": rqSpline.MaskedCouplingRQSpline, -} - -_LOCAL_SAMPLERS = {"mala": MALA.MALA, "hmc": HMC.HMC} - - -def get_nf_model_kwargs(nf_model, n_features, kwargs): - """Sets defaults and merges user-provided adaptation keywords.""" - defaults = { - # RealNVP kwargs - "n_hidden": 100, - "n_layer": 10, - # MaskedCouplingRQSpline kwargs - "n_layers": 4, - "num_bins": 8, - "hidden_size": [64, 64], - "spline_range": (-10.0, 10.0), - "n_features": n_features, - } | kwargs - - nf_model_kwargs, nf_model_required = shared.get_default_signature( - nf_model) - nf_model_kwargs.update( - {k: defaults[k] for k in nf_model_required if k in defaults}) - nf_model_required.remove("key") - nf_model_required.remove("kwargs") - nf_model_required = nf_model_required - nf_model_kwargs.keys() - - if nf_model_required: - raise ValueError( - "Unexpected required arguments: " - f"{','.join(nf_model_required)}. Probably file a bug, but " - "you can try to manually supply them as keywords." - ) - nf_model_kwargs.update( - {k: defaults[k] for k in nf_model_kwargs if k in defaults}) - - return nf_model_kwargs - - -def get_local_sampler_kwargs(local_sampler, log_density, n_features, kwargs): - """Sets defaults and merges user-provided adaptation keywords.""" - - defaults = { - # HMC kwargs - "condition_matrix": jnp.eye(n_features), - "n_leapfrog": 10, - # Both - "step_size": 0.1, - "logpdf": log_density - } | kwargs - - sampler_kwargs, sampler_required = shared.get_default_signature( - local_sampler) - sampler_kwargs.setdefault("jit", True) - sampler_kwargs.update( - {k: defaults[k] for k in sampler_required if k in defaults}) - sampler_required = sampler_required - sampler_kwargs.keys() - sampler_kwargs.update( - {k: defaults[k] for k in sampler_kwargs if k in defaults}) - sampler_required = sampler_required - sampler_kwargs.keys() - - if sampler_required: - raise ValueError( - "Unexpected required arguments: " - f"{','.join(sampler_required)}. Probably file a bug, but " - "you can try to manually supply them as keywords." - ) - return sampler_kwargs - - -def get_sampler_kwargs(sampler, n_features, kwargs): - """Sets defaults and merges user-provided adaptation keywords.""" - # We support `num_chains` everywhere else, so support it here. - if "num_chains" in kwargs: - kwargs["n_chains"] = kwargs["num_chains"] - defaults = { - "n_loop_training": 5, - "n_loop_production": 5, - "n_local_steps": 50, - "n_global_steps": 50, - "n_chains": 20, - "n_epochs": 30, - "learning_rate": 0.01, - "max_samples": 10_000, - "momentum": 0.9, - "batch_size": 10_000, - "use_global": True, - "global_sampler": None, - "logging": True, - "keep_quantile": 0., - "local_autotune": None, - "train_thinning": 1, - "output_thinning": 1, - "n_sample_max": 10_000, - "precompile": False, - "verbose": False, - "n_dim": n_features, - "data": {}} | kwargs - sampler_kwargs, sampler_required = shared.get_default_signature(sampler) - sampler_kwargs.update( - {k: defaults[k] for k in sampler_required if k in defaults}) - sampler_required = (sampler_required - - {"nf_model", "local_sampler", "rng_key", "kwargs"}) - sampler_required = sampler_required - sampler_kwargs.keys() - - if sampler_required: - raise ValueError( - "Unexpected required arguments: " - f"{','.join(sampler_required)}. Probably file a bug, but " - "you can try to manually supply them as keywords." - ) - return defaults | sampler_kwargs - - -class _FlowMCSampler(shared.Base): - """Base class for flowmc samplers.""" - name: str = "" - nf_model: str = "" - local_sampler: str = "" - - def _get_aux(self): - flat, unflatten = jax.flatten_util.ravel_pytree(self.test_point) - - @jax.vmap - def flatten(pytree): - return jax.flatten_util.ravel_pytree(pytree)[0] - - constrained_log_density = self.constrained_log_density() - def log_density(x, _): - return constrained_log_density(unflatten(x)).squeeze() - - return log_density, flatten, unflatten, flat.shape[0] - - def get_kwargs(self, **kwargs): - nf_model = _NF_MODELS[self.nf_model] - local_sampler = _LOCAL_SAMPLERS[self.local_sampler] - log_density, flatten, unflatten, n_features = self._get_aux() - - nf_model_kwargs = get_nf_model_kwargs(nf_model, n_features, kwargs) - local_sampler_kwargs = get_local_sampler_kwargs( - local_sampler, log_density, n_features, kwargs) - sampler = Sampler.Sampler - sampler_kwargs = get_sampler_kwargs(sampler, n_features, kwargs) - extra_parameters = {"flatten": flatten, - "unflatten": unflatten, - "num_chains": sampler_kwargs["n_chains"], - "return_pytree": kwargs.get("return_pytree", False)} - - return {nf_model: nf_model_kwargs, - local_sampler: local_sampler_kwargs, - sampler: sampler_kwargs, - "extra_parameters": extra_parameters} - - def __call__(self, seed, **kwargs): - kwargs = self.get_kwargs(**kwargs) - extra_parameters = kwargs["extra_parameters"] - num_chains = extra_parameters["num_chains"] - init_key, nf_key, seed = jax.random.split(seed, 3) - initial_state = self.get_initial_state( - init_key, num_chains=num_chains) - initial_state = extra_parameters["flatten"](initial_state) - nf_model = _NF_MODELS[self.nf_model] - local_sampler = _LOCAL_SAMPLERS[self.local_sampler] - - model = nf_model(key=nf_key, **kwargs[nf_model]) - local_sampler = local_sampler(**kwargs[local_sampler]) - sampler = Sampler.Sampler - nf_sampler = sampler( - rng_key=seed, - local_sampler=local_sampler, - nf_model=model, - **kwargs[sampler]) - nf_sampler.sample(initial_state, {}) - chains, *_ = nf_sampler.get_sampler_state().values() - - unflatten = jax.vmap(jax.vmap(extra_parameters["unflatten"])) - pytree = self.transform_fn(unflatten(chains)) - if extra_parameters["return_pytree"]: - return pytree - else: - if hasattr(pytree, "_asdict"): - pytree = pytree._asdict() - elif not isinstance(pytree, dict): - pytree = {"var0": pytree} - return az.from_dict(posterior=pytree) - - -class RealNVPMALA(_FlowMCSampler): - name = "flowmc_realnvp_mala" - nf_model = "real_nvp" - local_sampler = "mala" - - -class RealNVPHMC(_FlowMCSampler): - name = "flowmc_realnvp_hmc" - nf_model = "real_nvp" - local_sampler = "hmc" - - -class MaskedCouplingRQSplineMALA(_FlowMCSampler): - name = "flowmc_rqspline_mala" - nf_model = "masked_coupling_rq_spline" - local_sampler = "mala" - - -class MaskedCouplingRQSplineHMC(_FlowMCSampler): - name = "flowmc_rqspline_hmc" - nf_model = "masked_coupling_rq_spline" - local_sampler = "hmc" diff --git a/bayeux/mcmc/__init__.py b/bayeux/mcmc/__init__.py index 80a5315..074344b 100644 --- a/bayeux/mcmc/__init__.py +++ b/bayeux/mcmc/__init__.py @@ -35,18 +35,6 @@ "NUTSblackjax", "HMC_Pathfinder_blackjax", "NUTS_Pathfinder_blackjax"]) -if importlib.util.find_spec("flowMC") is not None: - from bayeux._src.mcmc.flowmc import MaskedCouplingRQSplineHMC as MaskedCouplingRQSplineHMCflowmc - from bayeux._src.mcmc.flowmc import MaskedCouplingRQSplineMALA as MaskedCouplingRQSplineMALAflowmc - from bayeux._src.mcmc.flowmc import RealNVPHMC as RealNVPHMCflowmc - from bayeux._src.mcmc.flowmc import RealNVPMALA as RealNVPMALAflowmc - - __all__.extend([ - "MaskedCouplingRQSplineHMCflowmc", - "MaskedCouplingRQSplineMALAflowmc", - "RealNVPHMCflowmc", - "RealNVPMALAflowmc"]) - if importlib.util.find_spec("numpyro") is not None: from bayeux._src.mcmc.numpyro import HMC as HMCnumpyro from bayeux._src.mcmc.numpyro import NUTS as NUTSnumpyro diff --git a/bayeux/tests/mcmc_test.py b/bayeux/tests/mcmc_test.py index 25f6af3..d2136de 100644 --- a/bayeux/tests/mcmc_test.py +++ b/bayeux/tests/mcmc_test.py @@ -95,25 +95,6 @@ def test_return_pytree_tfp_nuts(): assert pytree["x"]["y"].shape == (10, 4) -@pytest.mark.skipif(importlib.util.find_spec("flowMC") is None, - reason="Test requires flowMC which is not installed") -def test_return_pytree_flowmc(): - model = bx.Model(log_density=lambda pt: -jnp.sum(pt["x"]["y"]**2), - test_point={"x": {"y": jnp.array([1., 1.])}}) - seed = jax.random.PRNGKey(0) - pytree = model.mcmc.flowmc_realnvp_mala( - seed=seed, - return_pytree=True, - n_chains=4, - n_local_steps=1, - n_global_steps=1, - n_loop_training=1, - n_loop_production=5, - ) - # 10 draws = (1 local + 1 global) * 5 loops - assert pytree["x"]["y"].shape == (4, 10, 2) - - @pytest.mark.skipif(importlib.util.find_spec("nutpie") is None, reason="Test requires nutpie which is not installed") def test_return_pytree_nutpie(): diff --git a/docs/inspecting.md b/docs/inspecting.md index 6641134..fba8fa8 100644 --- a/docs/inspecting.md +++ b/docs/inspecting.md @@ -70,10 +70,6 @@ normal_model.methods 'blackjax_nuts', 'blackjax_hmc_pathfinder', 'blackjax_nuts_pathfinder', - 'flowmc_rqspline_hmc', - 'flowmc_rqspline_mala', - 'flowmc_realnvp_hmc', - 'flowmc_realnvp_mala', 'numpyro_hmc', 'numpyro_nuts'], 'optimize': ['jaxopt_bfgs', @@ -124,10 +120,6 @@ mcmc .blackjax_nuts .blackjax_hmc_pathfinder .blackjax_nuts_pathfinder - .flowmc_rqspline_hmc - .flowmc_rqspline_mala - .flowmc_realnvp_hmc - .flowmc_realnvp_mala .numpyro_hmc .numpyro_nuts optimize @@ -216,10 +208,6 @@ blackjax_meads_hmc blackjax_nuts blackjax_hmc_pathfinder blackjax_nuts_pathfinder -flowmc_rqspline_hmc -flowmc_rqspline_mala -flowmc_realnvp_hmc -flowmc_realnvp_mala numpyro_hmc numpyro_nuts ``` diff --git a/pyproject.toml b/pyproject.toml index 0998a1c..832ea40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,6 @@ dependencies = [ "optax", "optimistix", "blackjax", - "flowmc>=0.3.0", "numpyro", "jaxopt", "pymc",