@@ -24,7 +24,9 @@ A struct containing the results of Pareto-smoothed importance sampling.
2424
2525# Fields
2626
27- - `weights`: A vector of smoothed, truncated, and normalized importance sampling weights.
27+ - `log_weights`: A vector of smoothed and truncated but *unnormalized* importance sampling
28+ weights.
29+ - `weights`: A lazy
2830 - `pareto_k`: Estimates of the shape parameter `k` of the generalized Pareto distribution.
2931 - `ess`: Estimated effective sample size for each LOO evaluation, based on the variance of
3032 the weights.
@@ -102,7 +104,8 @@ See also: [`relative_eff`]@ref, [`psis_loo`]@ref, [`psis_ess`]@ref.
102104function psis (
103105 log_ratios:: AbstractArray{<:Real, 3} ;
104106 r_eff:: AbstractVector{<:Real} = similar (log_ratios, 0 ),
105- source:: Union{AbstractString, Symbol} = " mcmc"
107+ source:: Union{AbstractString, Symbol} = " mcmc" ,
108+ log_weights:: Bool = true
106109)
107110
108111 source = lowercase (String (source))
@@ -114,12 +117,8 @@ function psis(
114117 # Reshape to matrix (easier to deal with)
115118 log_ratios = reshape (log_ratios, data_size, post_sample_size)
116119 r_eff = _generate_r_eff (log_ratios, dims, r_eff, source)
117- weights = similar (log_ratios)
118- # Shift ratios by maximum to prevent overflow
119- @. weights = exp (log_ratios - $ maximum (log_ratios; dims= 2 ))
120-
121- r_eff = _generate_r_eff (weights, dims, r_eff, source)
122120 _check_input_validity_psis (reshape (log_ratios, dims), r_eff)
121+ weights = @. exp (log_ratios - $ maximum (log_ratios; dims= 2 ))
123122
124123 tail_length = Vector {Int} (undef, data_size)
125124 ξ = similar (r_eff)
@@ -159,36 +158,44 @@ function psis(
159158end
160159
161160
162- function psis (is_ratios:: AbstractVector{<:Real} , args... )
161+ function psis (is_ratios:: AbstractVector{<:Real} , args... ; kwargs ... )
163162 new_ratios = copy (is_ratios)
164- ξ = psis! (new_ratios)
163+ ξ = psis! (new_ratios, kwargs ... )
165164 return new_ratios, ξ
166165end
167166
168167
169168
170169"""
171- psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer) -> Real
172- psis!(is_ratios::AbstractVector{<:Real}, r_eff::Real) -> Real
170+ psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer; log_ratios=false ) -> Real
171+ psis!(is_ratios::AbstractVector{<:Real}, r_eff::Real; log_ratios=false ) -> Real
173172
174173Do PSIS on a single vector, smoothing its tail values *in place* before returning the
175- estimated tail value.
174+ estimated shape constant for the `pareto_k` distribution. This *does not* normalize the
175+ log-weights.
176176
177177# Arguments
178178
179179 - `is_ratios::AbstractVector{<:Real}`: A vector of importance sampling ratios,
180180 scaled to have a maximum of 1.
181- - `r_eff::AbstractVector{<:Real}`: A vector of relative effective sample sizes if .
181+ - `r_eff::AbstractVector{<:Real}`: The relative effective sample size, used for improving
182+ the .
183+ case `psis!` will automatically calculate the correct tail length.
184+ - `log_weights::Bool`: A boolean indicating whether the input vector is a vector of log
185+ ratios, rather than raw importance sampling ratios.
182186
183187# Returns
184188
185- - `T<: Real`: ξ, the shape parameter for the GPD; big numbers indicate thick tails.
189+ - `Real`: ξ, the shape parameter for the GPD. Bigger numbers indicate thicker tails.
186190
187191# Notes
188192
189- Unlike `psis`, `psis!` performs no checks to make sure the input values are valid.
193+ Unlike the methods for arrays, `psis!` performs no checks to make sure the input values are
194+ valid.
190195"""
191- function psis! (is_ratios:: AbstractVector{<:Real} , tail_length:: Integer )
196+ function psis! (is_ratios:: AbstractVector{<:Real} , tail_length:: Integer ;
197+ log_weights:: Bool = false
198+ )
192199
193200 len = length (is_ratios)
194201 tail_start = len - tail_length + 1 # index of smallest tail value
@@ -199,6 +206,10 @@ function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer)
199206 is_ratios .= first .(ratio_index)
200207 @views tail = is_ratios[tail_start: len]
201208 _check_tail (tail)
209+ if log_weights
210+ biggest = maximum (tail)
211+ @. tail = exp (tail - biggest)
212+ end
202213
203214 # Get value just before the tail starts:
204215 cutoff = is_ratios[tail_start - 1 ]
@@ -209,7 +220,11 @@ function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer)
209220 # unsort the ratios to their original position:
210221 invpermute! (is_ratios, last .(ratio_index))
211222
212- return ξ:: T
223+ if log_weights
224+ @. tail = log (tail + biggest)
225+ end
226+
227+ return ξ
213228end
214229
215230
222237"""
223238 _def_tail_length(log_ratios::AbstractVector, r_eff::Real) -> Integer
224239
225- Define the tail length as in Vehtari et al. (2019).
240+ Define the tail length as in Vehtari et al. (2019), with the small addition that the tail
241+ must a multiple of `32*bit_length` (which improves performance).
226242"""
227243function _def_tail_length (length:: Integer , r_eff:: Real = 1 )
228244 return min (cld (length, 5 ), ceil (3 * sqrt (length / r_eff))) |> Int
322338Check the tail to make sure a GPD fit is possible.
323339"""
324340function _check_tail (tail:: AbstractVector{T} ) where {T <: Real }
325- if maximum ( tail) ≈ minimum ( tail)
341+ if tail[ end ] ≈ tail[ 1 ]
326342 throw (
327343 ArgumentError (
328344 " Unable to fit generalized Pareto distribution: all tail values are the " *
@@ -333,7 +349,7 @@ function _check_tail(tail::AbstractVector{T}) where {T <: Real}
333349 throw (
334350 ArgumentError (
335351 " Unable to fit generalized Pareto distribution: tail length was too " *
336- " short. Likely causese are: \n $LIKELY_ERROR_CAUSES " ,
352+ " short. Likely causes are: \n $LIKELY_ERROR_CAUSES " ,
337353 ),
338354 )
339355 end
0 commit comments