Commit 8147aee

Add tests and multiple breaking changes
1 parent 38da648 commit 8147aee

8 files changed: +344, -133 lines changed
lines changed


bayesflow/amortizers.py

Lines changed: 12 additions & 11 deletions
@@ -23,7 +23,7 @@
 from tensorflow.keras import Model
 
 from bayesflow.exceptions import ConfigurationError, SummaryStatsError
-from bayesflow.losses import log_loss, mmd_summary_space
+from bayesflow.losses import log_loss, mmd_summary_space, kl_dirichlet
 from bayesflow.default_settings import DEFAULT_KEYS
 
 import tensorflow_probability as tfp
@@ -870,7 +870,7 @@ def __init__(self, evidence_net, summary_net=None, loss_fun=None, kl_weight=None
         self.summary_net = summary_net
         self.loss = self._determine_loss(loss_fun)
         self.kl_weight = kl_weight
-        self.n_models = self.evidence_net.n_models
+        self.num_models = self.evidence_net.num_models
 
     def __call__(self, input_dict, return_summary=False, **kwargs):
         """ Performs a forward pass through both networks.
@@ -887,7 +887,7 @@ def __call__(self, input_dict, return_summary=False, **kwargs):
 
         Returns
         -------
-        net_out : tf.Tensor of shape (batch_size, n_models) or tuple of (net_out (batch_size, n_models),
+        net_out : tf.Tensor of shape (batch_size, num_models) or tuple of (net_out (batch_size, num_models),
             summary_out (batch_size, summary_dim)), the latter being the summary network outputs, if
             `return_summary` is set to True.
         """
@@ -905,7 +905,7 @@ def __call__(self, input_dict, return_summary=False, **kwargs):
         return net_out, summary_out
 
     def compute_loss(self, input_dict, **kwargs):
-        """Computes the loss of the amortized model comparison.
+        """Computes the loss of the amortized model comparison instance.
 
         Parameters
         ----------
@@ -929,7 +929,7 @@ def compute_loss(self, input_dict, **kwargs):
         return loss + kl
 
     def sample(self, input_dict, to_numpy=True, **kwargs):
-        """ Samples posterior model probabilities from the higher order Dirichlet density.
+        """Samples posterior model probabilities from the higher order Dirichlet density.
 
         Parameters
         ----------
@@ -946,7 +946,7 @@ def sample(self, input_dict, to_numpy=True, **kwargs):
         Returns
         -------
         pm_samples : tf.Tensor or np.array
-            The posterior draws from the Dirichlet distribution, shape (n_samples, n_batch, n_models)
+            The posterior draws from the Dirichlet distribution, shape (num_samples, num_batch, num_models)
         """
 
         _, full_cond = self._compute_summary_condition(
@@ -958,7 +958,8 @@ def sample(self, input_dict, to_numpy=True, **kwargs):
         return self.evidence_net.sample(full_cond, to_numpy, **kwargs)
 
     def evidence(self, input_dict, to_numpy=True, **kwargs):
-        """TODO"""
+        """Computes the evidence for the competing models given the data sets
+        contained in `input_dict`."""
 
         _, full_cond = self._compute_summary_condition(
             input_dict.get(DEFAULT_KEYS['summary_conditions']),
@@ -972,7 +973,7 @@ def evidence(self, input_dict, to_numpy=True, **kwargs):
         return alphas
 
     def uncertainty_score(self, input_dict, to_numpy=True, **kwargs):
-        """TODO"""
+        """Computes the uncertainty score according to sum(alphas) / num_models."""
 
         _, full_cond = self._compute_summary_condition(
             input_dict.get(DEFAULT_KEYS['summary_conditions']),
@@ -981,7 +982,7 @@ def uncertainty_score(self, input_dict, to_numpy=True, **kwargs):
         )
 
         alphas = self(full_cond, return_summary=False, **kwargs)
-        u = tf.reduce_sum(alphas, axis=-1) / self.evidence_net.n_models
+        u = tf.reduce_sum(alphas, axis=-1) / self.evidence_net.num_models
         if to_numpy:
             return u.numpy()
         return u
@@ -1007,7 +1008,7 @@ def _compute_summary_condition(self, summary_conditions, direct_conditions, **kw
         return sum_condition, full_cond
 
     def _determine_loss(self, loss_fun):
-        """ Helper method to determine loss function to use."""
+        """Helper method to determine the loss function to use."""
 
         if loss_fun is None:
             return log_loss
@@ -1026,4 +1027,4 @@ def __init_subclass__(cls, **kwargs):
 
     def __init__(self, *args, **kwargs):
         warn(f'{self.__class__.__name__} will be deprecated. Use `AmortizedPosterior` instead.', DeprecationWarning, stacklevel=2)
-        super().__init__(*args, **kwargs)
\ No newline at end of file
+        super().__init__(*args, **kwargs)
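As a quick orientation for the renamed API, here is a minimal sketch of the reduction that `uncertainty_score` performs, written in plain TensorFlow outside of any amortizer; the `alphas` values are invented for illustration:

import tensorflow as tf

# Hypothetical concentration parameters for a batch of 2 data sets and 3 models.
alphas = tf.constant([[1.1, 1.2, 9.5],
                      [4.0, 4.0, 4.0]])
num_models = alphas.shape[-1]

# Same reduction as uncertainty_score in the hunk above:
# total concentration divided by the number of models.
u = tf.reduce_sum(alphas, axis=-1) / num_models
print(u.numpy())  # one score per data set, approximately [3.93, 4.0]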

bayesflow/default_settings.py

Lines changed: 9 additions & 2 deletions
@@ -84,6 +84,13 @@ def __init__(self, meta_dict: dict, mandatory_fields: list = []):
 }
 
 
+DEFAULT_SETTINGS_DENSE_EVIDENTIAL = {
+    'units': 128,
+    'kernel_initializer': 'lecun_normal',
+    'activation': 'selu',
+}
+
+
 DEFAULT_SETTING_DENSE_COUPLING = MetaDictSetting(
     meta_dict={
         't_args': {
@@ -120,10 +127,10 @@ def __init__(self, meta_dict: dict, mandatory_fields: list = []):
 DEFAULT_SETTING_EVIDENTIAL_NET = MetaDictSetting(
     meta_dict={
         'dense_args': dict(units=128, kernel_initializer='lecun_normal', activation='selu'),
-        'n_dense': 3,
+        'num_dense': 3,
         'output_activation': 'softplus'
     },
-    mandatory_fields=["n_models"]
+    mandatory_fields=["num_models"]
 )
 
 
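The new module-level dict makes the per-layer defaults directly importable. A minimal sketch of how such a dict is consumed by a Keras layer (only tf.keras is assumed):

import tensorflow as tf

# Mirrors DEFAULT_SETTINGS_DENSE_EVIDENTIAL from the hunk above.
dense_args = {
    'units': 128,
    'kernel_initializer': 'lecun_normal',
    'activation': 'selu',
}

# Each hidden layer of the evidential network unpacks these settings.
layer = tf.keras.layers.Dense(**dense_args)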

bayesflow/inference_networks.py

Lines changed: 52 additions & 27 deletions
@@ -209,39 +209,53 @@ def create_config(cls, **kwargs):
 
 class EvidentialNetwork(tf.keras.Model):
     """Implements a network whose outputs are the concentration parameters of a Dirichlet density.
+
+    Follows ideas from:
 
-    Follows the implementation from:
-    https://arxiv.org/abs/2004.10629
+    [1] Radev, S. T., D'Alessandro, M., Mertens, U. K., Voss, A., Köthe, U., & Bürkner, P. C. (2021).
+    Amortized Bayesian model comparison with evidential deep learning.
+    IEEE Transactions on Neural Networks and Learning Systems.
+
+    [2] Sensoy, M., Kaplan, L., & Kandemir, M. (2018).
+    Evidential deep learning to quantify classification uncertainty.
+    Advances in neural information processing systems, 31.
     """
 
-    def __init__(self, meta={}):
+    def __init__(self, num_models, dense_args=None, num_dense=3, output_activation='softplus', **kwargs):
         """Creates an instance of an evidential network for amortized model comparison.
 
         Parameters
         ----------
-        meta : dict
-            A list of dictionaries, where each dictionary holds parameter-value pairs
-            for a single :class:`tf.keras.Dense` layer
+        num_models : int
+            The number of candidate (competing) models for the comparison scenario.
+        dense_args : dict or None, optional, default: None
+            The arguments for a tf.keras.layers.Dense layer. If None, defaults will be used.
+        num_dense : int, optional, default: 3
+            The number of dense layers for the main network part.
+        output_activation : str or callable, optional, default: 'softplus'
+            The activation function to use for the network outputs.
+            Important: needs to have positive outputs.
+        **kwargs : dict, optional, default: {}
+            Optional keyword arguments (e.g., name) passed to the tf.keras.Model __init__ method.
         """
 
-        super().__init__()
+        super().__init__(**kwargs)
 
-        # Create settings dictionary
-        meta = build_meta_dict(user_dict=meta,
-                               default_setting=default_settings.DEFAULT_SETTING_EVIDENTIAL_NET)
+        if dense_args is None:
+            dense_args = default_settings.DEFAULT_SETTINGS_DENSE_EVIDENTIAL
 
         # A network to increase representation power
         self.dense = tf.keras.Sequential([
-            tf.keras.layers.Dense(**meta['dense_args'])
-            for _ in range(meta['n_dense'])
+            tf.keras.layers.Dense(**dense_args)
+            for _ in range(num_dense)
         ])
 
         # The layer to output model evidences
-        self.evidence_layer = tf.keras.layers.Dense(
-            meta['n_models'], activation=meta['output_activation'],
-            **{k: v for k, v in meta['dense_args'].items() if k != 'units' and k != 'activation'})
+        self.alpha_layer = tf.keras.layers.Dense(
+            num_models, activation=output_activation,
+            **{k: v for k, v in dense_args.items() if k != 'units' and k != 'activation'})
 
-        self.n_models = meta['n_models']
+        self.num_models = num_models
 
     def call(self, condition, **kwargs):
         """Computes evidences for model comparison given a batch of data and optional concatenated context,
@@ -254,13 +268,17 @@ def call(self, condition, **kwargs):
 
         Returns
         -------
-        alpha : tf.Tensor of shape (batch_size, n_models) -- the learned model evidences
+        evidence : tf.Tensor of shape (batch_size, num_models) -- the learned model evidences
         """
 
+        return self.evidence(condition, **kwargs)
+
+    @tf.function
+    def evidence(self, condition, **kwargs):
         rep = self.dense(condition, **kwargs)
-        evidence = self.evidence_layer(rep, **kwargs)
-        alpha = evidence + 1
-        return alpha
+        alpha = self.alpha_layer(rep, **kwargs)
+        evidence = alpha + 1.
+        return evidence
 
     def sample(self, condition, n_samples, **kwargs):
         """Samples posterior model probabilities from the higher-order Dirichlet density.
@@ -271,17 +289,24 @@ def sample(self, condition, n_samples, **kwargs):
             The summary of the observed (or simulated) data, shape (n_data_sets, ...)
         n_samples : int
             Number of samples to obtain from the approximate posterior
-
+
         Returns
         -------
         pm_samples : tf.Tensor or np.array
-            The posterior draws from the Dirichlet distribution, shape (n_samples, n_batch, n_models)
+            The posterior draws from the Dirichlet distribution, shape (num_samples, num_batch, num_models)
         """
 
-        # Compute evidential values
-        alpha = self(condition, **kwargs)
+        alpha = self.evidence(condition, **kwargs)
         n_datasets = alpha.shape[0]
-
-        # Sample for each dataset
-        pm_samples = np.stack([np.random.dirichlet(alpha[n, :], size=n_samples) for n in range(n_datasets)], axis=1)
+        pm_samples = np.stack(
+            [np.random.default_rng().dirichlet(alpha[n, :], size=n_samples) for n in range(n_datasets)], axis=1)
         return pm_samples
+
+    @classmethod
+    def create_config(cls, **kwargs):
+        """Used to create the settings dictionary for the internal networks of the evidential
+        network. Will fill in missing settings with default values."""
+
+        settings = build_meta_dict(user_dict=kwargs,
+                                   default_setting=default_settings.DEFAULT_SETTING_EVIDENTIAL_NET)
+        return settings
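Taken together, these hunks replace the old meta={...} constructor with plain keyword arguments. A hedged usage sketch of the new interface; the input shapes are invented, and the import path simply follows this file's name:

import tensorflow as tf
from bayesflow.inference_networks import EvidentialNetwork

# Keyword-based construction replacing the old EvidentialNetwork(meta={...}).
net = EvidentialNetwork(num_models=3, num_dense=3, output_activation='softplus')

# Toy condition: 8 data sets summarized into 16 features each.
condition = tf.random.normal((8, 16))

evidence = net(condition)                          # shape (8, 3), entries > 1 (softplus output + 1)
pm_samples = net.sample(condition, n_samples=100)  # shape (100, 8, 3)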

bayesflow/simulation.py

Lines changed: 3 additions & 3 deletions
@@ -848,15 +848,15 @@ def __init__(self, generative_models: list, model_probs='equal'):
         """
 
         self.generative_models = generative_models
-        self.n_models = len(generative_models)
+        self.num_models = len(generative_models)
         self.model_prior = self._determine_model_prior(model_probs)
 
     def _determine_model_prior(self, model_probs):
         """Creates the model prior p(M) given user input."""
 
         if model_probs == 'equal':
-            return lambda b: np.random.randint(self.n_models, size=b)
-        return lambda b: np.random.default_rng().choice(self.n_models, size=b, p=model_probs)
+            return lambda b: np.random.randint(self.num_models, size=b)
+        return lambda b: np.random.default_rng().choice(self.num_models, size=b, p=model_probs)
 
     def __call__(self, batch_size, **kwargs):
 
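Apart from the renaming, the prior logic is unchanged: 'equal' draws uniform model indices, while a probability vector goes through NumPy's Generator.choice. A standalone sketch with made-up probabilities:

import numpy as np

num_models = 3

# 'equal' branch: uniform model indices.
equal_prior = lambda b: np.random.randint(num_models, size=b)

# Custom prior over the candidate models (illustrative values, must sum to 1).
model_probs = [0.5, 0.25, 0.25]
custom_prior = lambda b: np.random.default_rng().choice(num_models, size=b, p=model_probs)

print(equal_prior(5))   # e.g. [2 0 1 1 0]
print(custom_prior(5))  # indices drawn according to model_probs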
