@@ -86,6 +86,10 @@ class AmortizedPosterior(tf.keras.Model, AmortizedTarget):
8686 Tails of lipschitz triangular flows.
8787 In International Conference on Machine Learning (pp. 4673-4681). PMLR.
8888
89+ [4] Alexanderson, S., & Henter, G. E. (2020).
90+ Robust model training and generalisation with Studentising flows.
91+ arXiv preprint arXiv:2006.06599.
92+
8993 Serves as an interface for learning ``p(parameters | data, context).``
9094 """
9195
@@ -662,7 +666,11 @@ def _determine_latent_dist(self, latent_dist):
662666
663667class AmortizedPosteriorLikelihood (tf .keras .Model , AmortizedTarget ):
664668 """An interface for jointly learning a surrogate model of the simulator and an approximate
665- posterior given a generative model.
669+ posterior given a generative model, as proposed by:
670+
671+ [1] Radev, S. T., Schmitt, M., Pratz, V., Picchini, U., Köthe, U., & Bürkner, P. C. (2023).
672+ JANA: Jointly Amortized Neural Approximation of Complex Bayesian Models.
673+ arXiv preprint arXiv:2302.09125.
666674 """
667675
668676 def __init__ (self , amortized_posterior , amortized_likelihood , ** kwargs ):
@@ -671,9 +679,11 @@ def __init__(self, amortized_posterior, amortized_likelihood, **kwargs):
671679 Parameters
672680 ----------
673681 amortized_posterior : an instance of AmortizedPosterior or a custom tf.keras.Model
674- The generative neural posterior approximator.
682+ The generative neural posterior approximator
675683 amortized_likelihood : an instance of AmortizedLikelihood or a custom tf.keras.Model
676- The generative neural likelihood approximator.
684+ The generative neural likelihood approximator
685+ **kwargs : dict, optional, default: {}
686+ Additional keyword arguments passed to the ``__init__`` method of a ``tf.keras.Model`` instance
677687 """
678688
679689 tf .keras .Model .__init__ (self , ** kwargs )
@@ -878,7 +888,7 @@ class AmortizedModelComparison(tf.keras.Model):
878888 arXiv preprint arXiv:2301.11873.
879889
880890 Note: the original paper [1] does not distinguish between the summary and the evidential networks, but
881- treats them as a whole, with the appropriate architetcure dictated by the model application. For the
891+ treats them as a whole, with the appropriate architecture dictated by the model application. For the
882892 sake of consistency and modularity, the BayesFlow library separates the two constructs.
883893 """
884894
@@ -954,6 +964,9 @@ def posterior_probs(self, input_dict, to_numpy=True, **kwargs):
954964 `direct_conditions` - the conditioning variables that are directly passed to the evidential network
955965 to_numpy : bool, optional, default: True
956966 Flag indicating whether to return the PMPs a ``np.ndarray`` or a ``tf.Tensor``
967+ **kwargs : dict, optional, default: {}
968+ Additional keyword arguments passed to the networks
969+
957970 Returns
958971 -------
959972 out : tf.Tensor of shape (batch_size, ..., num_models)
@@ -991,7 +1004,7 @@ def compute_loss(self, input_dict, **kwargs):
9911004 return loss
9921005
9931006 def _compute_summary_condition (self , summary_conditions , direct_conditions , ** kwargs ):
994- """Determines how to concatenate the provided conditions."""
1007 """Helper method to determine how to concatenate the provided conditions."""
9951008
9961009 # Compute learnable summaries, if given
9971010 if self .summary_net is not None :
0 commit comments