@@ -869,7 +869,7 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
869869 )
870870
871871
872- def psislw (log_weights , reff = 1.0 ):
872+ def psislw (log_weights , reff = 1.0 , normalize = True ):
873873 """
874874 Pareto smoothed importance sampling (PSIS).
875875
@@ -887,11 +887,13 @@ def psislw(log_weights, reff=1.0):
887887 Array of size (n_observations, n_samples)
888888 reff : float, default 1
889889 relative MCMC efficiency, ``ess / n``
890+ normalize : bool, default True
891+ return normalized log weights
890892
891893 Returns
892894 -------
893895 lw_out : DataArray or (..., N) ndarray
894- Smoothed, truncated and normalized log weights.
896+ Smoothed, truncated and possibly normalized log weights.
895897 kss : DataArray or (...) ndarray
896898 Estimates of the shape parameter *k* of the generalized Pareto
897899 distribution.
@@ -936,7 +938,12 @@ def psislw(log_weights, reff=1.0):
936938 out = np .empty_like (log_weights ), np .empty (shape )
937939
938940 # define kwargs
939- func_kwargs = {"cutoff_ind" : cutoff_ind , "cutoffmin" : cutoffmin , "out" : out }
941+ func_kwargs = {
942+ "cutoff_ind" : cutoff_ind ,
943+ "cutoffmin" : cutoffmin ,
944+ "out" : out ,
945+ "normalize" : normalize ,
946+ }
940947 ufunc_kwargs = {"n_dims" : 1 , "n_output" : 2 , "ravel" : False , "check_shape" : False }
941948 kwargs = {"input_core_dims" : [["__sample__" ]], "output_core_dims" : [["__sample__" ], []]}
942949 log_weights , pareto_shape = _wrap_xarray_ufunc (
@@ -953,7 +960,7 @@ def psislw(log_weights, reff=1.0):
953960 return log_weights , pareto_shape
954961
955962
956- def _psislw (log_weights , cutoff_ind , cutoffmin ):
963+ def _psislw (log_weights , cutoff_ind , cutoffmin , normalize ):
957964 """
958965 Pareto smoothed importance sampling (PSIS) for a 1D vector.
959966
@@ -963,7 +970,7 @@ def _psislw(log_weights, cutoff_ind, cutoffmin):
963970 Array of length n_observations
964971 cutoff_ind: int
965972 cutoffmin: float
966- k_min: float
973+ normalize: bool
967974
968975 Returns
969976 -------
@@ -975,7 +982,8 @@ def _psislw(log_weights, cutoff_ind, cutoffmin):
975982 x = np .asarray (log_weights )
976983
977984 # improve numerical accuracy
978- x -= np .max (x )
985+ max_x = np .max (x )
986+ x -= max_x
979987 # sort the array
980988 x_sort_ind = np .argsort (x )
981989 # divide log weights into body and right tail
@@ -1007,8 +1015,12 @@ def _psislw(log_weights, cutoff_ind, cutoffmin):
10071015 x [tailinds [x_tail_si ]] = smoothed_tail
10081016 # truncate smoothed values to the largest raw weight 0
10091017 x [x > 0 ] = 0
1018+
10101019 # renormalize weights
1011- x -= _logsumexp (x )
1020+ if normalize :
1021+ x -= _logsumexp (x )
1022+ else :
1023+ x += max_x
10121024
10131025 return x , k
10141026
0 commit comments