@@ -209,39 +209,53 @@ def create_config(cls, **kwargs):
209209
class EvidentialNetwork(tf.keras.Model):
    """Implements a network whose outputs are the concentration parameters of a Dirichlet density.

    Follows ideas from:

    [1] Radev, S. T., D'Alessandro, M., Mertens, U. K., Voss, A., Köthe, U., & Bürkner, P. C. (2021).
    Amortized Bayesian model comparison with evidential deep learning.
    IEEE Transactions on Neural Networks and Learning Systems.

    [2] Sensoy, M., Kaplan, L., & Kandemir, M. (2018).
    Evidential deep learning to quantify classification uncertainty.
    Advances in neural information processing systems, 31.
    """

    def __init__(self, num_models, dense_args=None, num_dense=3, output_activation='softplus', **kwargs):
        """Creates an instance of an evidential network for amortized model comparison.

        Parameters
        ----------
        num_models : int
            The number of candidate (competing models) for the comparison scenario.
        dense_args : dict or None, optional, default: None
            The arguments for a tf.keras.layers.Dense layer. If None, defaults will be used.
        num_dense : int, optional, default: 3
            The number of dense layers for the main network part.
        output_activation : str or callable, optional, default: 'softplus'
            The activation function to use for the network outputs.
            Important: needs to have positive outputs.
        **kwargs : dict, optional, default: {}
            Optional keyword arguments (e.g., name) passed to the tf.keras.Model __init__ method.
        """

        super().__init__(**kwargs)

        if dense_args is None:
            dense_args = default_settings.DEFAULT_SETTINGS_DENSE_EVIDENTIAL

        # A network to increase representation power
        self.dense = tf.keras.Sequential([
            tf.keras.layers.Dense(**dense_args)
            for _ in range(num_dense)
        ])

        # The layer to output model evidences: reuses dense_args except for
        # 'units' and 'activation', which are fixed by num_models / output_activation
        self.alpha_layer = tf.keras.layers.Dense(
            num_models, activation=output_activation,
            **{k: v for k, v in dense_args.items() if k != 'units' and k != 'activation'})

        self.num_models = num_models

    def call(self, condition, **kwargs):
        """Computes evidences for model comparison given a batch of data and optional concatenated context,
        typically passed through a summary network.

        Parameters
        ----------
        condition : tf.Tensor of shape (batch_size, ...)
            The input conditions (e.g., learned summary statistics of the data sets).

        Returns
        -------
        evidence : tf.Tensor of shape (batch_size, num_models) -- the learned model evidences
        """

        return self.evidence(condition, **kwargs)

    @tf.function
    def evidence(self, condition, **kwargs):
        """Computes the evidences (Dirichlet concentration parameters) for the given conditions.

        The raw network outputs (positive, due to the output activation) are shifted by +1
        so that the resulting concentration parameters are always > 1, i.e., the implied
        Dirichlet density is unimodal.

        Parameters
        ----------
        condition : tf.Tensor of shape (batch_size, ...)
            The input conditions (e.g., learned summary statistics of the data sets).

        Returns
        -------
        evidence : tf.Tensor of shape (batch_size, num_models) -- the learned model evidences
        """
        rep = self.dense(condition, **kwargs)
        alpha = self.alpha_layer(rep, **kwargs)
        evidence = alpha + 1.
        return evidence

    def sample(self, condition, n_samples, **kwargs):
        """Samples posterior model probabilities from the higher-order Dirichlet density.

        Parameters
        ----------
        condition : tf.Tensor
            The summary of the observed (or simulated) data, shape (n_data_sets, ...)
        n_samples : int
            Number of samples to obtain from the approximate posterior

        Returns
        -------
        pm_samples : tf.Tensor or np.array
            The posterior draws from the Dirichlet distribution, shape (num_samples, num_batch, num_models)
        """

        alpha = self.evidence(condition, **kwargs)
        n_datasets = alpha.shape[0]
        # Fix: np.default_rng does not exist -- the Generator constructor lives in
        # np.random. Also create the generator once instead of once per data set.
        rng = np.random.default_rng()
        pm_samples = np.stack(
            [rng.dirichlet(alpha[n, :], size=n_samples) for n in range(n_datasets)], axis=1)
        return pm_samples

    @classmethod
    def create_config(cls, **kwargs):
        """Used to create the settings dictionary for the internal networks of the evidential
        network. Will fill in missing entries with the corresponding default settings.

        Parameters
        ----------
        **kwargs : dict, optional, default: {}
            User-supplied settings that override the defaults.

        Returns
        -------
        settings : dict
            The merged settings dictionary.
        """

        settings = build_meta_dict(user_dict=kwargs,
                                   default_setting=default_settings.DEFAULT_SETTING_EVIDENTIAL_NET)
        return settings
0 commit comments