@@ -1036,6 +1036,179 @@ def _determine_loss(self, loss_fun):
10361036 )
10371037
10381038
1039+ class TwoLevelAmortizedPosterior (tf .keras .Model , AmortizedTarget ):
1040+ """An interface for estimating arbitrary two level hierarchical Bayesian models."""
1041+
1042+ def __init__ (self , local_amortizer , global_amortizer , summary_net = None , ** kwargs ):
1043+ """Creates an wrapper for estimating two-level hierarchical Bayesian models.
1044+
1045+ Parameters
1046+ ----------
1047+ local_amortizer : bayesflow.amortizers.AmortizedPosterior
1048+ A posterior amortizer without a summary network which will estimate
1049+ the full conditional of the (varying numbers of) local parameter vectors.
1050+ global_amortizer : bayesflow.amortizers.AmortizedPosterior
1051+ A posterior amortizer without a summary network which will estimate the joint
1052+ posterior of hyperparameters and optional shared parameters given a representation
1053+ of an entire hierarchical data set. If both hyper- and shared parameters are present,
1054+ the first dimensions correspond to the hyperparameters and the remaining ones correspond
1055+ to the shared parameters.
1056+ summary_net : tf.keras.Model or None, optional, default: None
1057+ An optional summary network to compress non-vector data structures.
1058+ **kwargs : dict, optional, default: {}
1059+ Additional keyword arguments passed to the ``__init__`` method of a ``tf.keras.Model`` instance.
1060+ """
1061+
1062+ super ().__init__ (** kwargs )
1063+
1064+ self .local_amortizer = local_amortizer
1065+ self .global_amortizer = global_amortizer
1066+ self .summary_net = summary_net
1067+
1068+ def call (self , input_dict , ** kwargs ):
1069+ """Forward pass through the hierarchical amortized posterior."""
1070+
1071+ local_summaries , global_summaries = self ._compute_condition (input_dict , ** kwargs )
1072+ local_inputs , global_inputs = self ._prepare_inputs (input_dict , local_summaries , global_summaries )
1073+ local_out = self .local_amortizer (local_inputs , ** kwargs )
1074+ global_out = self .global_amortizer (global_inputs , ** kwargs )
1075+ return local_out , global_out
1076+
1077+ def compute_loss (self , input_dict , ** kwargs ):
1078+ """Compute loss of all amortizers."""
1079+
1080+ local_summaries , global_summaries = self ._compute_condition (input_dict , ** kwargs )
1081+ local_inputs , global_inputs = self ._prepare_inputs (input_dict , local_summaries , global_summaries )
1082+ local_loss = self .local_amortizer .compute_loss (local_inputs , ** kwargs )
1083+ global_loss = self .global_amortizer .compute_loss (global_inputs , ** kwargs )
1084+ return {"Local.Loss" : local_loss , "Global.Loss" : global_loss }
1085+
1086+ def sample (self , input_dict , n_samples , to_numpy = True , ** kwargs ):
1087+ """Obtains samples from the joint hierarchical posterior given observations.
1088+
1089+ Important: Currently works only for single hierarchical data sets!
1090+
1091+ Parameters
1092+ ----------
1093+ input_dict : dict
1094+ Input dictionary containing the following mandatory keys, if DEFAULT_KEYS unchanged:
1095+ `summary_conditions` - the hierarchical data set (to be embedded by the summary net)
1096+ As well as optional keys:
1097+ `direct_local_conditions` - (Context) variables used to condition the local posterior
1098+ `direct_global_conditions` - (Context) variables used to condition the global posterior
1099+ n_samples : int
1100+ The number of posterior draws (samples) to obtain from the approximate posterior
1101+ to_numpy : bool, optional, default: True
1102+ Flag indicating whether to return the samples as a `np.array` or a `tf.Tensor`
1103+ **kwargs : dict, optional, default: {}
1104+ Additional keyword arguments passed to the summary network as the amortizers
1105+
1106+ Returns:
1107+ --------
1108+ samples_dict : dict
1109+ A dictionary with keys `global_samples` and `local_samples`
1110+ Global samples will hold an array-like of shape (num_samples, num_replicas, num_local)
1111+ and local samples will hold an array-like of shape (num_samples, num_hyper + num_shared),
1112+ if optional shared patameters are present, otherwise (num_samples, num_hyper),
1113+ """
1114+
1115+ # Returned shapes will be
1116+ # local_summaries.shape = (1, num_groups, summary_dim_local)
1117+ # global_summaries.shape = (1, summary_dim_global)
1118+ local_summaries , global_summaries = self ._get_local_global (input_dict , ** kwargs )
1119+ num_groups = local_summaries .shape [1 ]
1120+
1121+ if local_summaries .shape [0 ] != 1 or global_summaries .shape [0 ] != 1 :
1122+ raise NotImplementedError ("Method currently supports only single hierarchical data sets!" )
1123+
1124+ # Obtain samples from p(global | all_data)
1125+ inp_global = {DEFAULT_KEYS ["direct_conditions" ]: global_summaries }
1126+
1127+ # New, shape will be (n_samples, num_globals)
1128+ global_samples = self .global_amortizer .sample (inp_global , n_samples , ** kwargs , to_numpy = False )
1129+
1130+ # Repeat local conditions for n_samples
1131+ # New shape -> (num_groups, n_samples, summary_dim_local)
1132+ local_summaries = tf .stack ([tf .squeeze (local_summaries , axis = 0 )] * n_samples , axis = 1 )
1133+
1134+ # Repeat global samples for num_groups
1135+ # New shape -> (num_groups, n_samples, num_globals)
1136+ global_samples_rep = tf .stack ([global_samples ] * num_groups , axis = 0 )
1137+
1138+ # Concatenate local summaries with global samples
1139+ # New shape -> (num_groups, num_samples, summary_dim_local + num_globals)
1140+ local_summaries = tf .concat ([local_summaries , global_samples_rep ], axis = - 1 )
1141+
1142+ # Obtain samples from p(local_i | data_i, global_i)
1143+ inp_local = {DEFAULT_KEYS ["direct_conditions" ]: local_summaries }
1144+ local_samples = self .local_amortizer .sample (inp_local , n_samples , to_numpy = False , ** kwargs )
1145+
1146+ if to_numpy :
1147+ global_samples = global_samples .numpy ()
1148+ local_samples = local_samples .numpy ()
1149+
1150+ return {"global_samples" : global_samples , "local_samples" : local_samples }
1151+
1152+ def log_prob (self , input_dict ):
1153+ """Compute normalized log density."""
1154+
1155+ raise NotImplementedError
1156+
1157+ def _prepare_inputs (self , input_dict , local_summaries , global_summaries ):
1158+ """Prepare input dictionaries for both amortizers."""
1159+
1160+ # Prepare inputs for local amortizer
1161+ local_inputs = {"direct_conditions" : local_summaries , "parameters" : input_dict ["local_parameters" ]}
1162+
1163+ # Prepare inputs for global amortizer
1164+ _parameters = input_dict ["hyper_parameters" ]
1165+ if input_dict .get ("shared_parameters" ) is not None :
1166+ _parameters = tf .concat ([_parameters , input_dict .get ("shared_parameters" )], axis = - 1 )
1167+ global_inputs = {"direct_conditions" : global_summaries , "parameters" : _parameters }
1168+ return local_inputs , global_inputs
1169+
1170+ def _compute_condition (self , input_dict , ** kwargs ):
1171+ """Determines conditionining variables for both amortizers."""
1172+
1173+ # Obtain needed summaries
1174+ local_summaries , global_summaries = self ._get_local_global (input_dict , ** kwargs )
1175+
1176+ # At this point, add globals as conditions
1177+ num_locals = local_summaries .shape [1 ]
1178+
1179+ # Add hyper parameters as conditions:
1180+ # p(local_n | data_n, hyper)
1181+ if input_dict .get ("hyper_parameters" ) is not None :
1182+ _params = input_dict .get ("hyper_parameters" )
1183+ _conds = tf .stack ([_params ] * num_locals , axis = 1 )
1184+ local_summaries = tf .concat ([local_summaries , _conds ], axis = - 1 )
1185+ # Add shared parameters as conditions:
1186+ # p(local_n | data_n, hyper, shared)
1187+ if input_dict .get ("shared_parameters" ) is not None :
1188+ _params = input_dict .get ("shared_parameters" )
1189+ _conds = tf .stack ([_params ] * num_locals , axis = 1 )
1190+ local_summaries = tf .concat ([local_summaries , _conds ], axis = - 1 )
1191+ return local_summaries , global_summaries
1192+
1193+ def _get_local_global (self , input_dict , ** kwargs ):
1194+ """Helper function to obtain local and global condition tensors."""
1195+
1196+ # Obtain summary conditions
1197+ if self .summary_net is not None :
1198+ local_summaries , global_summaries = self .summary_net (
1199+ input_dict ["summary_conditions" ], return_all = True , ** kwargs
1200+ )
1201+ if input_dict .get ("direct_local_conditions" ) is not None :
1202+ local_summaries = tf .concat ([local_summaries , input_dict .get ("direct_local_conditions" )], axis = - 1 )
1203+ if input_dict .get ("direct_global_conditions" ) is not None :
1204+ global_summaries = tf .concat ([global_summaries , input_dict .get ("direct_global_conditions" )], axis = - 1 )
1205+ # If no summary net provided, assume direct conditions exist or fail
1206+ else :
1207+ local_summaries = input_dict .get ("direct_local_conditions" )
1208+ global_summaries = input_dict .get ("direct_global_conditions" )
1209+ return local_summaries , global_summaries
1210+
1211+
10391212class SingleModelAmortizer (AmortizedPosterior ):
10401213 """Deprecated class for amortizer posterior estimation."""
10411214
0 commit comments