Skip to content

Commit 5b1c4ab

Browse files
authored
optional normalization for psislw (#2455)
1 parent 1b0b9cb commit 5b1c4ab

File tree

1 file changed

+19
-7
lines changed

1 file changed

+19
-7
lines changed

arviz/stats/stats.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -869,7 +869,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
869869
)
870870

871871

872-
def psislw(log_weights, reff=1.0):
872+
def psislw(log_weights, reff=1.0, normalize=True):
873873
"""
874874
Pareto smoothed importance sampling (PSIS).
875875
@@ -887,11 +887,13 @@ def psislw(log_weights, reff=1.0):
887887
Array of size (n_observations, n_samples)
888888
reff : float, default 1
889889
relative MCMC efficiency, ``ess / n``
890+
normalize : bool, default True
891+
return normalized log weights
890892
891893
Returns
892894
-------
893895
lw_out : DataArray or (..., N) ndarray
894-
Smoothed, truncated and normalized log weights.
896+
Smoothed, truncated and possibly normalized log weights.
895897
kss : DataArray or (...) ndarray
896898
Estimates of the shape parameter *k* of the generalized Pareto
897899
distribution.
@@ -936,7 +938,12 @@ def psislw(log_weights, reff=1.0):
936938
out = np.empty_like(log_weights), np.empty(shape)
937939

938940
# define kwargs
939-
func_kwargs = {"cutoff_ind": cutoff_ind, "cutoffmin": cutoffmin, "out": out}
941+
func_kwargs = {
942+
"cutoff_ind": cutoff_ind,
943+
"cutoffmin": cutoffmin,
944+
"out": out,
945+
"normalize": normalize,
946+
}
940947
ufunc_kwargs = {"n_dims": 1, "n_output": 2, "ravel": False, "check_shape": False}
941948
kwargs = {"input_core_dims": [["__sample__"]], "output_core_dims": [["__sample__"], []]}
942949
log_weights, pareto_shape = _wrap_xarray_ufunc(
@@ -953,7 +960,7 @@ def psislw(log_weights, reff=1.0):
953960
return log_weights, pareto_shape
954961

955962

956-
def _psislw(log_weights, cutoff_ind, cutoffmin):
963+
def _psislw(log_weights, cutoff_ind, cutoffmin, normalize):
957964
"""
958965
Pareto smoothed importance sampling (PSIS) for a 1D vector.
959966
@@ -963,7 +970,7 @@ def _psislw(log_weights, cutoff_ind, cutoffmin):
963970
Array of length n_observations
964971
cutoff_ind: int
965972
cutoffmin: float
966-
k_min: float
973+
normalize: bool
967974
968975
Returns
969976
-------
@@ -975,7 +982,8 @@ def _psislw(log_weights, cutoff_ind, cutoffmin):
975982
x = np.asarray(log_weights)
976983

977984
# improve numerical accuracy
978-
x -= np.max(x)
985+
max_x = np.max(x)
986+
x -= max_x
979987
# sort the array
980988
x_sort_ind = np.argsort(x)
981989
# divide log weights into body and right tail
@@ -1007,8 +1015,12 @@ def _psislw(log_weights, cutoff_ind, cutoffmin):
10071015
x[tailinds[x_tail_si]] = smoothed_tail
10081016
# truncate smoothed values to the largest raw weight 0
10091017
x[x > 0] = 0
1018+
10101019
# renormalize weights
1011-
x -= _logsumexp(x)
1020+
if normalize:
1021+
x -= _logsumexp(x)
1022+
else:
1023+
x += max_x
10121024

10131025
return x, k
10141026

0 commit comments

Comments
 (0)