Skip to content

Commit 303127d

Browse files
committed
Support log-prob in PointApproximator
1 parent 5cb8995 commit 303127d

File tree

1 file changed

+42
-4
lines changed

1 file changed

+42
-4
lines changed

bayesflow/approximators/point_approximator.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,47 @@ def sample(
111111
if split:
112112
raise NotImplementedError("split=True is currently not supported for `PointApproximator`.")
113113
samples = split_arrays(samples, axis=-1)
114-
# Squeeze samples if there's only one key-value pair.
115-
samples = self._squeeze_samples(samples)
114+
# Squeeze sample dictionary if there's only one key-value pair.
115+
samples = self._squeeze_parametric_score_major_dict(samples)
116116

117117
return samples
118118

119+
def log_prob(
120+
self,
121+
*,
122+
data: dict[str, np.ndarray],
123+
**kwargs,
124+
) -> np.ndarray | dict[str, np.ndarray]:
125+
"""
126+
Computes the log-probability of given data under the parametric distribution(s) for given input conditions.
127+
128+
Parameters
129+
----------
130+
data : dict[str, np.ndarray]
131+
A dictionary mapping variable names to arrays representing the inference conditions and variables.
132+
**kwargs
133+
Additional keyword arguments passed to underlying processing functions.
134+
135+
Returns
136+
-------
137+
log_prob : np.ndarray or dict[str, np.ndarray]
138+
Log-probabilities of the distribution `p(inference_variables | inference_conditions, h(summary_conditions))`
139+
for all parametric scoring rules.
140+
141+
If only one parametric score is available, output is an array of log-probabilities.
142+
143+
Output is a dictionary if multiple parametric scores are available.
144+
Then, each key is the name of a score and values are corresponding log-probabilities.
145+
146+
147+
Log-probabilities have shape (num_datasets,).
148+
"""
149+
log_prob = super().log_prob(data=data, **kwargs)
150+
# Squeeze log probabilities dictionary if there's only one key-value pair.
151+
log_prob = self._squeeze_parametric_score_major_dict(log_prob)
152+
153+
return log_prob
154+
119155
def _prepare_conditions(self, conditions: dict[str, np.ndarray], **kwargs) -> dict[str, Tensor]:
120156
"""Adapts and converts the conditions to tensors."""
121157
conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
@@ -187,8 +223,10 @@ def _squeeze_estimates(
187223
}
188224
return squeezed
189225

190-
def _squeeze_samples(self, samples: dict[str, np.ndarray]) -> np.ndarray or dict[str, np.ndarray]:
191-
"""Squeezes the samples dictionary to just the value if there is only one key-value pair."""
226+
def _squeeze_parametric_score_major_dict(
227+
self, samples: dict[str, np.ndarray]
228+
) -> np.ndarray or dict[str, np.ndarray]:
229+
"""Squeezes the dictionary to just the value if there is only one key-value pair."""
192230
if len(samples) == 1:
193231
return next(iter(samples.values())) # Extract and return the only item's value
194232
return samples

0 commit comments

Comments
 (0)