From a1468018557297b6ce3704db746bd90783d8f608 Mon Sep 17 00:00:00 2001 From: Kucharssim Date: Mon, 28 Apr 2025 13:55:39 +0200 Subject: [PATCH 1/8] drop or fill missing keys from the output --- .../simulators/model_comparison_simulator.py | 78 ++++++++++++++++++- 1 file changed, 76 insertions(+), 2 deletions(-) diff --git a/bayesflow/simulators/model_comparison_simulator.py b/bayesflow/simulators/model_comparison_simulator.py index 60174ef92..aa52571d0 100644 --- a/bayesflow/simulators/model_comparison_simulator.py +++ b/bayesflow/simulators/model_comparison_simulator.py @@ -6,6 +6,7 @@ from bayesflow.utils.decorators import allow_batch_size from bayesflow.utils import numpy_utils as npu +from bayesflow.utils import logging from types import FunctionType @@ -22,6 +23,7 @@ def __init__( p: Sequence[float] = None, logits: Sequence[float] = None, use_mixed_batches: bool = True, + key_conflicts: str | float = "drop", shared_simulator: Simulator | Callable[[Sequence[int]], dict[str, any]] = None, ): """ @@ -38,8 +40,14 @@ def __init__( A sequence of logits corresponding to model probabilities. Mutually exclusive with `p`. If neither `p` nor `logits` is provided, defaults to uniform logits. use_mixed_batches : bool, optional - If True, samples in a batch are drawn from different models. If False, the entire batch - is drawn from a single model chosen according to the model probabilities. Default is True. + Whether to draw samples in a batch from different models. + - If True (default), each sample in a batch may come from a different model. + - If False, the entire batch is drawn from a single model, selected according to model probabilities. + key_conflicts : {"drop"} | float, optional + Policy for handling keys that are missing in the output of some models, when using mixed batches. + - "drop" (default): Drop conflicting keys from the batch output. + - float: Fill missing keys with the specified value. + - If neither "drop" nor a float is given, an error is raised when key conflicts are detected. shared_simulator : Simulator or Callable, optional A shared simulator whose outputs are passed to all model simulators. If a function is provided, it is wrapped in a `LambdaSimulator` with batching enabled. @@ -68,6 +76,8 @@ def __init__( self.logits = logits self.use_mixed_batches = use_mixed_batches + self.key_conflicts = key_conflicts + self._keys = None @allow_batch_size def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]: @@ -105,6 +115,7 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]: sims = [ simulator.sample(n, **(kwargs | data)) for simulator, n in zip(self.simulators, model_counts) if n > 0 ] + sims = self._handle_key_conflicts(sims, model_counts) sims = tree_concatenate(sims, numpy=True) data |= sims @@ -118,3 +129,66 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]: model_indices = npu.one_hot(np.full(batch_shape, model_index, dtype="int32"), num_models) return data | {"model_indices": model_indices} + + def _handle_key_conflicts(self, sims, batch_sizes): + batch_sizes = [b for b in batch_sizes if b > 0] + + keys, all_keys, common_keys, missing_keys = self._determine_key_conflicts(sims=sims) + + # all sims have the same keys + if all_keys == common_keys: + return sims + + # keep only common keys + if self.key_conflicts == "drop": + sims = [{k: v for k, v in sim.items() if k in common_keys} for sim in sims] + return sims + + # try to fill values with key_conflicts to shape of sims from other models + if isinstance(self.key_conflicts, (float, int)): + combined_sims = {} + for sim in sims: + combined_sims = combined_sims | sim + + for i, sim in enumerate(sims): + for missing_key in missing_keys[i]: + shape = combined_sims[missing_key].shape + shape = [s for s in shape] + shape[0] = batch_sizes[i] + + sim[missing_key] = np.full(shape=shape, fill_value=self.key_conflicts) + + return sims + + raise ValueError( + "Key conflicts are found in model simulations and no valid `key_conflicts` policy was provided." + ) + + def _determine_key_conflicts(self, sims): + # determine only once + if self._keys is not None: + return self._keys + + keys = [set(sim.keys()) for sim in sims] + all_keys = set.union(*keys) + common_keys = set.intersection(*keys) + missing_keys = [all_keys - k for k in keys] + + self._keys = keys, all_keys, common_keys, missing_keys + + if all_keys == common_keys: + return self._keys + + if self.key_conflicts == "drop": + logging.info( + f"Incompatible simulator output. \ +The following keys will be dropped: {', '.join(sorted(all_keys - common_keys))}." + ) + elif isinstance(self.key_conflicts, (float, int)): + logging.info( + f"Incompatible simulator output. \ +Attempting to replace keys: {', '.join(sorted(all_keys - common_keys))}, where missing, \ +with value {self.key_conflicts}." + ) + + return self._keys From 00e79a6821c5b0e2eaa0bdef30884368ab3af821 Mon Sep 17 00:00:00 2001 From: Kucharssim Date: Tue, 29 Apr 2025 09:28:03 +0200 Subject: [PATCH 2/8] fix typo --- bayesflow/simulators/model_comparison_simulator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bayesflow/simulators/model_comparison_simulator.py b/bayesflow/simulators/model_comparison_simulator.py index aa52571d0..6226f03cf 100644 --- a/bayesflow/simulators/model_comparison_simulator.py +++ b/bayesflow/simulators/model_comparison_simulator.py @@ -144,7 +144,7 @@ def _handle_key_conflicts(self, sims, batch_sizes): sims = [{k: v for k, v in sim.items() if k in common_keys} for sim in sims] return sims - # try to fill values with key_conflicts to shape of sims from other models + # try to fill with key_conflicts to shape of the values from other model if isinstance(self.key_conflicts, (float, int)): combined_sims = {} for sim in sims: @@ -153,7 +153,7 @@ def _handle_key_conflicts(self, sims, batch_sizes): for i, sim in enumerate(sims): for missing_key in missing_keys[i]: shape = combined_sims[missing_key].shape - shape = [s for s in shape] + shape = list(shape) shape[0] = batch_sizes[i] sim[missing_key] = np.full(shape=shape, fill_value=self.key_conflicts) From 1a1868b56f42a5ead3d9a6b5cf4a4cc8d1911911 Mon Sep 17 00:00:00 2001 From: Kucharssim Date: Tue, 29 Apr 2025 09:58:59 +0200 Subject: [PATCH 3/8] add test --- tests/test_simulators/conftest.py | 26 ++++++++++++++++++++++++ tests/test_simulators/test_simulators.py | 10 +++++++++ 2 files changed, 36 insertions(+) diff --git a/tests/test_simulators/conftest.py b/tests/test_simulators/conftest.py index 0e76a5396..be92cd6df 100644 --- a/tests/test_simulators/conftest.py +++ b/tests/test_simulators/conftest.py @@ -167,6 +167,32 @@ def likelihood(mu, n): return make_simulator([prior, likelihood], meta_fn=context) +@pytest.fixture(params=["drop", np.nan]) +def multimodel(request): + from bayesflow.simulators import make_simulator, ModelComparisonSimulator + + rng = np.random.default_rng() + + def prior_1(): + return dict(w=rng.uniform()) + + def prior_2(): + return dict(c=rng.uniform()) + + def model_1(w): + return dict(x=w) + + def model_2(c): + return dict(x=c) + + simulator_1 = make_simulator([prior_1, model_1]) + simulator_2 = make_simulator([prior_2, model_2]) + + simulator = ModelComparisonSimulator(simulators=[simulator_1, simulator_2], key_conflicts=request.param) + + return simulator + + @pytest.fixture() def fixed_n(): return 5 diff --git a/tests/test_simulators/test_simulators.py b/tests/test_simulators/test_simulators.py index e9a3c80c0..0c9248290 100644 --- a/tests/test_simulators/test_simulators.py +++ b/tests/test_simulators/test_simulators.py @@ -47,3 +47,13 @@ def test_fixed_sample(composite_gaussian, batch_size, fixed_n, fixed_mu): assert samples["mu"].shape == (batch_size, 1) assert np.all(samples["mu"] == fixed_mu) assert samples["y"].shape == (batch_size, fixed_n) + + +def test_multimodel_sample(multimodel, batch_size): + samples = multimodel.sample(batch_size) + + if multimodel.key_conflicts == "drop": + assert set(samples) == {"x", "model_indices"} + else: + assert set(samples) == {"x", "model_indices", "c", "w"} + assert np.sum(np.isnan(samples["c"])) + np.sum(np.isnan(samples["w"])) == batch_size From 7f33a250d4c01b55ce70f7d0164e856a321b0991 Mon Sep 17 00:00:00 2001 From: Kucharssim Date: Wed, 30 Apr 2025 09:35:07 +0200 Subject: [PATCH 4/8] address code review from Lars --- .../simulators/model_comparison_simulator.py | 32 ++++++++----------- tests/test_simulators/conftest.py | 28 ++++++++++++++-- tests/test_simulators/test_simulators.py | 16 ++++++++-- 3 files changed, 54 insertions(+), 22 deletions(-) diff --git a/bayesflow/simulators/model_comparison_simulator.py b/bayesflow/simulators/model_comparison_simulator.py index 6226f03cf..0fcda3c57 100644 --- a/bayesflow/simulators/model_comparison_simulator.py +++ b/bayesflow/simulators/model_comparison_simulator.py @@ -23,7 +23,8 @@ def __init__( p: Sequence[float] = None, logits: Sequence[float] = None, use_mixed_batches: bool = True, - key_conflicts: str | float = "drop", + key_conflicts: str = "drop", + fill_value: float = np.nan, shared_simulator: Simulator | Callable[[Sequence[int]], dict[str, any]] = None, ): """ @@ -43,11 +44,13 @@ def __init__( Whether to draw samples in a batch from different models. - If True (default), each sample in a batch may come from a different model. - If False, the entire batch is drawn from a single model, selected according to model probabilities. - key_conflicts : {"drop"} | float, optional + key_conflicts : str, optional Policy for handling keys that are missing in the output of some models, when using mixed batches. - "drop" (default): Drop conflicting keys from the batch output. - - float: Fill missing keys with the specified value. - - If neither "drop" nor a float is given, an error is raised when key conflicts are detected. + - "fill": Fill missing keys with the specified value. + - "error": An error is raised when key conflicts are detected. + fill_value : float, optional + If `key_conflicts=="fill"`, the missing keys will be filled with the value of this argument. shared_simulator : Simulator or Callable, optional A shared simulator whose outputs are passed to all model simulators. If a function is provided, it is wrapped in a `LambdaSimulator` with batching enabled. @@ -77,6 +80,7 @@ def __init__( self.logits = logits self.use_mixed_batches = use_mixed_batches self.key_conflicts = key_conflicts + self.fill_value = fill_value self._keys = None @allow_batch_size @@ -139,30 +143,22 @@ def _handle_key_conflicts(self, sims, batch_sizes): if all_keys == common_keys: return sims - # keep only common keys if self.key_conflicts == "drop": sims = [{k: v for k, v in sim.items() if k in common_keys} for sim in sims] return sims - - # try to fill with key_conflicts to shape of the values from other model - if isinstance(self.key_conflicts, (float, int)): + elif self.key_conflicts == "fill": combined_sims = {} for sim in sims: combined_sims = combined_sims | sim - for i, sim in enumerate(sims): for missing_key in missing_keys[i]: shape = combined_sims[missing_key].shape shape = list(shape) shape[0] = batch_sizes[i] - - sim[missing_key] = np.full(shape=shape, fill_value=self.key_conflicts) - + sim[missing_key] = np.full(shape=shape, fill_value=self.fill_value) return sims - - raise ValueError( - "Key conflicts are found in model simulations and no valid `key_conflicts` policy was provided." - ) + elif self.key_conflicts == "error": + raise ValueError("Key conflicts are found in simulator outputs, cannot combine them into one batch.") def _determine_key_conflicts(self, sims): # determine only once @@ -184,11 +180,11 @@ def _determine_key_conflicts(self, sims): f"Incompatible simulator output. \ The following keys will be dropped: {', '.join(sorted(all_keys - common_keys))}." ) - elif isinstance(self.key_conflicts, (float, int)): + elif self.key_conflicts == "fill": logging.info( f"Incompatible simulator output. \ Attempting to replace keys: {', '.join(sorted(all_keys - common_keys))}, where missing, \ -with value {self.key_conflicts}." +with value {self.fill_value}." ) return self._keys diff --git a/tests/test_simulators/conftest.py b/tests/test_simulators/conftest.py index be92cd6df..7dcc22c12 100644 --- a/tests/test_simulators/conftest.py +++ b/tests/test_simulators/conftest.py @@ -167,8 +167,32 @@ def likelihood(mu, n): return make_simulator([prior, likelihood], meta_fn=context) -@pytest.fixture(params=["drop", np.nan]) -def multimodel(request): +@pytest.fixture() +def multimodel(): + from bayesflow.simulators import make_simulator, ModelComparisonSimulator + + def context(batch_size): + return dict(n=np.random.randint(10, 100)) + + def prior_0(): + return dict(mu=0) + + def prior_1(): + return dict(mu=np.random.standard_normal()) + + def likelihood(n, mu): + return dict(y=np.random.normal(mu, 1, n)) + + simulator_0 = make_simulator([prior_0, likelihood]) + simulator_1 = make_simulator([prior_1, likelihood]) + + simulator = ModelComparisonSimulator(simulators=[simulator_0, simulator_1], shared_simulator=context) + + return simulator + + +@pytest.fixture(params=["drop", "fill", "error"]) +def multimodel_key_conflicts(request): from bayesflow.simulators import make_simulator, ModelComparisonSimulator rng = np.random.default_rng() diff --git a/tests/test_simulators/test_simulators.py b/tests/test_simulators/test_simulators.py index 0c9248290..4e2174be3 100644 --- a/tests/test_simulators/test_simulators.py +++ b/tests/test_simulators/test_simulators.py @@ -1,3 +1,4 @@ +import pytest import keras import numpy as np @@ -52,8 +53,19 @@ def test_fixed_sample(composite_gaussian, batch_size, fixed_n, fixed_mu): def test_multimodel_sample(multimodel, batch_size): samples = multimodel.sample(batch_size) - if multimodel.key_conflicts == "drop": + assert set(samples) == {"n", "mu", "y", "model_indices"} + assert samples["mu"].shape == (batch_size, 1) + assert samples["y"].shape == (batch_size, samples["n"]) + + +def test_multimodel_key_conflicts_sample(multimodel_key_conflicts, batch_size): + if multimodel_key_conflicts.key_conflicts == "drop": + samples = multimodel_key_conflicts.sample(batch_size) assert set(samples) == {"x", "model_indices"} - else: + elif multimodel_key_conflicts.key_conflicts == "fill": + samples = multimodel_key_conflicts.sample(batch_size) assert set(samples) == {"x", "model_indices", "c", "w"} assert np.sum(np.isnan(samples["c"])) + np.sum(np.isnan(samples["w"])) == batch_size + elif multimodel_key_conflicts.key_conflicts == "error": + with pytest.raises(Exception): + samples = multimodel_key_conflicts.sample(batch_size) From b486784ec8ba4aa2d59b6e5267234b5c7b65aeb2 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Wed, 30 Apr 2025 15:34:34 +0000 Subject: [PATCH 5/8] formatting in the docstring add newlines to correctly render lists, make reference to other class a link --- bayesflow/simulators/model_comparison_simulator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/bayesflow/simulators/model_comparison_simulator.py b/bayesflow/simulators/model_comparison_simulator.py index 0fcda3c57..fb1aae098 100644 --- a/bayesflow/simulators/model_comparison_simulator.py +++ b/bayesflow/simulators/model_comparison_simulator.py @@ -42,10 +42,12 @@ def __init__( If neither `p` nor `logits` is provided, defaults to uniform logits. use_mixed_batches : bool, optional Whether to draw samples in a batch from different models. + - If True (default), each sample in a batch may come from a different model. - If False, the entire batch is drawn from a single model, selected according to model probabilities. key_conflicts : str, optional Policy for handling keys that are missing in the output of some models, when using mixed batches. + - "drop" (default): Drop conflicting keys from the batch output. - "fill": Fill missing keys with the specified value. - "error": An error is raised when key conflicts are detected. @@ -53,7 +55,7 @@ def __init__( If `key_conflicts=="fill"`, the missing keys will be filled with the value of this argument. shared_simulator : Simulator or Callable, optional A shared simulator whose outputs are passed to all model simulators. If a function is - provided, it is wrapped in a `LambdaSimulator` with batching enabled. + provided, it is wrapped in a :py:class:`~bayesflow.simulators.LambdaSimulator` with batching enabled. """ self.simulators = simulators From 39de28788d0cfddb30dd3f5263a09acbadfc55b6 Mon Sep 17 00:00:00 2001 From: Kucharssim Date: Mon, 12 May 2025 10:08:44 +0200 Subject: [PATCH 6/8] check keys every time, issue warning only once --- .../simulators/model_comparison_simulator.py | 32 +++++++++---------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/bayesflow/simulators/model_comparison_simulator.py b/bayesflow/simulators/model_comparison_simulator.py index fb1aae098..72dce01ec 100644 --- a/bayesflow/simulators/model_comparison_simulator.py +++ b/bayesflow/simulators/model_comparison_simulator.py @@ -83,7 +83,7 @@ def __init__( self.use_mixed_batches = use_mixed_batches self.key_conflicts = key_conflicts self.fill_value = fill_value - self._keys = None + self._key_conflicts_warning = True @allow_batch_size def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]: @@ -163,30 +163,28 @@ def _handle_key_conflicts(self, sims, batch_sizes): raise ValueError("Key conflicts are found in simulator outputs, cannot combine them into one batch.") def _determine_key_conflicts(self, sims): - # determine only once - if self._keys is not None: - return self._keys - keys = [set(sim.keys()) for sim in sims] all_keys = set.union(*keys) common_keys = set.intersection(*keys) missing_keys = [all_keys - k for k in keys] - self._keys = keys, all_keys, common_keys, missing_keys - if all_keys == common_keys: - return self._keys + return keys, all_keys, common_keys, missing_keys - if self.key_conflicts == "drop": - logging.info( - f"Incompatible simulator output. \ + if self._key_conflicts_warning: + # issue warning only once + self._key_conflicts_warning = False + + if self.key_conflicts == "drop": + logging.info( + f"Incompatible simulator output. \ The following keys will be dropped: {', '.join(sorted(all_keys - common_keys))}." - ) - elif self.key_conflicts == "fill": - logging.info( - f"Incompatible simulator output. \ + ) + elif self.key_conflicts == "fill": + logging.info( + f"Incompatible simulator output. \ Attempting to replace keys: {', '.join(sorted(all_keys - common_keys))}, where missing, \ with value {self.fill_value}." - ) + ) - return self._keys + return keys, all_keys, common_keys, missing_keys From fc7cf0d6fd86765f2f53eef6d952481ef7bb100b Mon Sep 17 00:00:00 2001 From: Kucharssim Date: Mon, 12 May 2025 10:12:33 +0200 Subject: [PATCH 7/8] more specific error check --- tests/test_simulators/test_simulators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_simulators/test_simulators.py b/tests/test_simulators/test_simulators.py index 4e2174be3..f1996c82e 100644 --- a/tests/test_simulators/test_simulators.py +++ b/tests/test_simulators/test_simulators.py @@ -67,5 +67,5 @@ def test_multimodel_key_conflicts_sample(multimodel_key_conflicts, batch_size): assert set(samples) == {"x", "model_indices", "c", "w"} assert np.sum(np.isnan(samples["c"])) + np.sum(np.isnan(samples["w"])) == batch_size elif multimodel_key_conflicts.key_conflicts == "error": - with pytest.raises(Exception): + with pytest.raises(ValueError): samples = multimodel_key_conflicts.sample(batch_size) From d10fe111d1bdd4b1b73d9e4d49432c3bc2ac2dca Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Tue, 13 May 2025 05:23:45 +0000 Subject: [PATCH 8/8] [no ci] minor edits to types and error message --- bayesflow/simulators/model_comparison_simulator.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/bayesflow/simulators/model_comparison_simulator.py b/bayesflow/simulators/model_comparison_simulator.py index 72dce01ec..4b1d4095a 100644 --- a/bayesflow/simulators/model_comparison_simulator.py +++ b/bayesflow/simulators/model_comparison_simulator.py @@ -9,6 +9,7 @@ from bayesflow.utils import logging from types import FunctionType +from typing import Literal from .simulator import Simulator from .lambda_simulator import LambdaSimulator @@ -23,7 +24,7 @@ def __init__( p: Sequence[float] = None, logits: Sequence[float] = None, use_mixed_batches: bool = True, - key_conflicts: str = "drop", + key_conflicts: Literal["drop", "fill", "error"] = "drop", fill_value: float = np.nan, shared_simulator: Simulator | Callable[[Sequence[int]], dict[str, any]] = None, ): @@ -160,7 +161,9 @@ def _handle_key_conflicts(self, sims, batch_sizes): sim[missing_key] = np.full(shape=shape, fill_value=self.fill_value) return sims elif self.key_conflicts == "error": - raise ValueError("Key conflicts are found in simulator outputs, cannot combine them into one batch.") + raise ValueError( + "Different simulators provide outputs with different keys, cannot combine them into one batch." + ) def _determine_key_conflicts(self, sims): keys = [set(sim.keys()) for sim in sims]