
Commit 9768a4d

Update defaults
1 parent 31a6251 commit 9768a4d

3 files changed (+508, -7 lines)

bayesflow/amortizers.py

Lines changed: 173 additions & 0 deletions
@@ -1036,6 +1036,179 @@ def _determine_loss(self, loss_fun):
        )


class TwoLevelAmortizedPosterior(tf.keras.Model, AmortizedTarget):
    """An interface for estimating arbitrary two-level hierarchical Bayesian models."""

    def __init__(self, local_amortizer, global_amortizer, summary_net=None, **kwargs):
        """Creates a wrapper for estimating two-level hierarchical Bayesian models.

        Parameters
        ----------
        local_amortizer  : bayesflow.amortizers.AmortizedPosterior
            A posterior amortizer without a summary network which will estimate
            the full conditional of the (varying numbers of) local parameter vectors.
        global_amortizer : bayesflow.amortizers.AmortizedPosterior
            A posterior amortizer without a summary network which will estimate the joint
            posterior of hyperparameters and optional shared parameters given a representation
            of an entire hierarchical data set. If both hyper- and shared parameters are present,
            the first dimensions correspond to the hyperparameters and the remaining ones correspond
            to the shared parameters.
        summary_net      : tf.keras.Model or None, optional, default: None
            An optional summary network to compress non-vector data structures.
        **kwargs         : dict, optional, default: {}
            Additional keyword arguments passed to the ``__init__`` method of a ``tf.keras.Model`` instance.
        """

        super().__init__(**kwargs)

        self.local_amortizer = local_amortizer
        self.global_amortizer = global_amortizer
        self.summary_net = summary_net
    def call(self, input_dict, **kwargs):
        """Forward pass through the hierarchical amortized posterior."""

        local_summaries, global_summaries = self._compute_condition(input_dict, **kwargs)
        local_inputs, global_inputs = self._prepare_inputs(input_dict, local_summaries, global_summaries)
        local_out = self.local_amortizer(local_inputs, **kwargs)
        global_out = self.global_amortizer(global_inputs, **kwargs)
        return local_out, global_out

    def compute_loss(self, input_dict, **kwargs):
        """Compute loss of all amortizers."""

        local_summaries, global_summaries = self._compute_condition(input_dict, **kwargs)
        local_inputs, global_inputs = self._prepare_inputs(input_dict, local_summaries, global_summaries)
        local_loss = self.local_amortizer.compute_loss(local_inputs, **kwargs)
        global_loss = self.global_amortizer.compute_loss(global_inputs, **kwargs)
        return {"Local.Loss": local_loss, "Global.Loss": global_loss}
    def sample(self, input_dict, n_samples, to_numpy=True, **kwargs):
        """Obtains samples from the joint hierarchical posterior given observations.

        Important: Currently works only for single hierarchical data sets!

        Parameters
        ----------
        input_dict : dict
            Input dictionary containing the following mandatory keys, if DEFAULT_KEYS unchanged:
            `summary_conditions` - the hierarchical data set (to be embedded by the summary net)
            As well as optional keys:
            `direct_local_conditions`  - (Context) variables used to condition the local posterior
            `direct_global_conditions` - (Context) variables used to condition the global posterior
        n_samples  : int
            The number of posterior draws (samples) to obtain from the approximate posterior
        to_numpy   : bool, optional, default: True
            Flag indicating whether to return the samples as a `np.ndarray` or a `tf.Tensor`
        **kwargs   : dict, optional, default: {}
            Additional keyword arguments passed to the summary network and the amortizers

        Returns
        -------
        samples_dict : dict
            A dictionary with keys `global_samples` and `local_samples`.
            Local samples will hold an array-like of shape (num_samples, num_replicas, num_local)
            and global samples will hold an array-like of shape (num_samples, num_hyper + num_shared),
            if optional shared parameters are present, otherwise (num_samples, num_hyper).
        """

        # Returned shapes will be
        # local_summaries.shape  = (1, num_groups, summary_dim_local)
        # global_summaries.shape = (1, summary_dim_global)
        local_summaries, global_summaries = self._get_local_global(input_dict, **kwargs)
        num_groups = local_summaries.shape[1]

        if local_summaries.shape[0] != 1 or global_summaries.shape[0] != 1:
            raise NotImplementedError("Method currently supports only single hierarchical data sets!")

        # Obtain samples from p(global | all_data)
        inp_global = {DEFAULT_KEYS["direct_conditions"]: global_summaries}

        # New shape will be (n_samples, num_globals)
        global_samples = self.global_amortizer.sample(inp_global, n_samples, **kwargs, to_numpy=False)

        # Repeat local conditions for n_samples
        # New shape -> (num_groups, n_samples, summary_dim_local)
        local_summaries = tf.stack([tf.squeeze(local_summaries, axis=0)] * n_samples, axis=1)

        # Repeat global samples for num_groups
        # New shape -> (num_groups, n_samples, num_globals)
        global_samples_rep = tf.stack([global_samples] * num_groups, axis=0)

        # Concatenate local summaries with global samples
        # New shape -> (num_groups, n_samples, summary_dim_local + num_globals)
        local_summaries = tf.concat([local_summaries, global_samples_rep], axis=-1)

        # Obtain samples from p(local_i | data_i, global_i)
        inp_local = {DEFAULT_KEYS["direct_conditions"]: local_summaries}
        local_samples = self.local_amortizer.sample(inp_local, n_samples, to_numpy=False, **kwargs)

        if to_numpy:
            global_samples = global_samples.numpy()
            local_samples = local_samples.numpy()

        return {"global_samples": global_samples, "local_samples": local_samples}

    def log_prob(self, input_dict):
        """Compute normalized log density."""

        raise NotImplementedError

    def _prepare_inputs(self, input_dict, local_summaries, global_summaries):
        """Prepare input dictionaries for both amortizers."""

        # Prepare inputs for local amortizer
        local_inputs = {"direct_conditions": local_summaries, "parameters": input_dict["local_parameters"]}

        # Prepare inputs for global amortizer
        _parameters = input_dict["hyper_parameters"]
        if input_dict.get("shared_parameters") is not None:
            _parameters = tf.concat([_parameters, input_dict.get("shared_parameters")], axis=-1)
        global_inputs = {"direct_conditions": global_summaries, "parameters": _parameters}
        return local_inputs, global_inputs

    def _compute_condition(self, input_dict, **kwargs):
        """Determines conditioning variables for both amortizers."""

        # Obtain needed summaries
        local_summaries, global_summaries = self._get_local_global(input_dict, **kwargs)

        # At this point, add globals as conditions
        num_locals = local_summaries.shape[1]

        # Add hyper parameters as conditions:
        # p(local_n | data_n, hyper)
        if input_dict.get("hyper_parameters") is not None:
            _params = input_dict.get("hyper_parameters")
            _conds = tf.stack([_params] * num_locals, axis=1)
            local_summaries = tf.concat([local_summaries, _conds], axis=-1)
        # Add shared parameters as conditions:
        # p(local_n | data_n, hyper, shared)
        if input_dict.get("shared_parameters") is not None:
            _params = input_dict.get("shared_parameters")
            _conds = tf.stack([_params] * num_locals, axis=1)
            local_summaries = tf.concat([local_summaries, _conds], axis=-1)
        return local_summaries, global_summaries

    def _get_local_global(self, input_dict, **kwargs):
        """Helper function to obtain local and global condition tensors."""

        # Obtain summary conditions
        if self.summary_net is not None:
            local_summaries, global_summaries = self.summary_net(
                input_dict["summary_conditions"], return_all=True, **kwargs
            )
            if input_dict.get("direct_local_conditions") is not None:
                local_summaries = tf.concat([local_summaries, input_dict.get("direct_local_conditions")], axis=-1)
            if input_dict.get("direct_global_conditions") is not None:
                global_summaries = tf.concat([global_summaries, input_dict.get("direct_global_conditions")], axis=-1)
        # If no summary net provided, assume direct conditions exist or fail
        else:
            local_summaries = input_dict.get("direct_local_conditions")
            global_summaries = input_dict.get("direct_global_conditions")
        return local_summaries, global_summaries

class SingleModelAmortizer(AmortizedPosterior):
"""Deprecated class for amortizer posterior estimation."""

bayesflow/default_settings.py

Lines changed: 7 additions & 1 deletion
@@ -176,12 +176,18 @@ def __init__(self, meta_dict: dict, mandatory_fields: list = []):
     "non_batchable_context": "non_batchable_context",
     "prior_batchable_context": "prior_batchable_context",
     "prior_non_batchable_context": "prior_non_batchable_context",
+    "prior_context": "prior_context",
+    "hyper_prior_draws": "hyper_prior_draws",
+    "shared_prior_draws": "shared_prior_draws",
+    "local_prior_draws": "local_prior_draws",
     "sim_batchable_context": "sim_batchable_context",
     "sim_non_batchable_context": "sim_non_batchable_context",
     "summary_conditions": "summary_conditions",
     "direct_conditions": "direct_conditions",
     "parameters": "parameters",
-    "hyperparameters": "hyperparameters",
+    "hyper_parameters": "hyper_parameters",
+    "shared_parameters": "shared_parameters",
+    "local_parameters": "local_parameters",
     "observables": "observables",
     "targets": "targets",
     "conditions": "conditions",
