@@ -111,11 +111,47 @@ def sample(
111111 if split :
112112 raise NotImplementedError ("split=True is currently not supported for `PointApproximator`." )
113113 samples = split_arrays (samples , axis = - 1 )
114- # Squeeze samples if there's only one key-value pair.
115- samples = self ._squeeze_samples (samples )
114+ # Squeeze sample dictionary if there's only one key-value pair.
115+ samples = self ._squeeze_parametric_score_major_dict (samples )
116116
117117 return samples
118118
119+ def log_prob (
120+ self ,
121+ * ,
122+ data : dict [str , np .ndarray ],
123+ ** kwargs ,
124+ ) -> np .ndarray | dict [str , np .ndarray ]:
125+ """
126+ Computes the log-probability of given data under the parametric distribution(s) for given input conditions.
127+
128+ Parameters
129+ ----------
130+ data : dict[str, np.ndarray]
131+ A dictionary mapping variable names to arrays representing the inference conditions and variables.
132+ **kwargs
133+ Additional keyword arguments passed to underlying processing functions.
134+
135+ Returns
136+ -------
137+ log_prob : np.ndarray or dict[str, np.ndarray]
138+ Log-probabilities of the distribution `p(inference_variables | inference_conditions, h(summary_conditions))`
139+ for all parametric scoring rules.
140+
141+ If only one parametric score is available, output is an array of log-probabilities.
142+
143+ Output is a dictionary if multiple parametric scores are available.
144+ Then, each key is the name of a score and values are corresponding log-probabilities.
145+
146+
147+ Log-probabilities have shape (num_datasets,).
148+ """
149+ log_prob = super ().log_prob (data = data , ** kwargs )
150+ # Squeeze log probabilities dictionary if there's only one key-value pair.
151+ log_prob = self ._squeeze_parametric_score_major_dict (log_prob )
152+
153+ return log_prob
154+
119155 def _prepare_conditions (self , conditions : dict [str , np .ndarray ], ** kwargs ) -> dict [str , Tensor ]:
120156 """Adapts and converts the conditions to tensors."""
121157 conditions = self .adapter (conditions , strict = False , stage = "inference" , ** kwargs )
@@ -187,8 +223,10 @@ def _squeeze_estimates(
187223 }
188224 return squeezed
189225
190- def _squeeze_samples (self , samples : dict [str , np .ndarray ]) -> np .ndarray or dict [str , np .ndarray ]:
191- """Squeezes the samples dictionary to just the value if there is only one key-value pair."""
226+ def _squeeze_parametric_score_major_dict (
227+ self , samples : dict [str , np .ndarray ]
228+ ) -> np .ndarray or dict [str , np .ndarray ]:
229+ """Squeezes the dictionary to just the value if there is only one key-value pair."""
192230 if len (samples ) == 1 :
193231 return next (iter (samples .values ())) # Extract and return the only item's value
194232 return samples
0 commit comments