@@ -9,7 +9,7 @@ double check it is correct.
99const MIN_TAIL_LEN = 5 # Minimum size of a tail for PSIS to give sensible answers
1010const SAMPLE_SOURCES = [" mcmc" , " vi" , " other" ]
1111
12- export psis, psis!, PsisLoo, PsisLooMethod, Psis
12+ export psis, psis!, Psis
1313
1414
1515# ##########################
@@ -24,9 +24,7 @@ A struct containing the results of Pareto-smoothed importance sampling.
2424
2525# Fields
2626
27- - `log_weights`: A vector of smoothed and truncated but *unnormalized* importance sampling
28- weights.
29- - `weights`: A lazy
27+ - `weights`: A vector of smoothed, truncated, and normalized importance sampling weights.
3028 - `pareto_k`: Estimates of the shape parameter `k` of the generalized Pareto distribution.
3129 - `ess`: Estimated effective sample size for each LOO evaluation, based on the variance of
3230 the weights.
@@ -39,21 +37,38 @@ A struct containing the results of Pareto-smoothed importance sampling.
3937 - `data_size`: How many data points were used for PSIS.
4038"""
struct Psis{
    R <: Real,
    AT <: AbstractArray{R, 3},
    VT <: AbstractVector{R},
    IT <: AbstractVector{Int}
}
    # Smoothed, truncated, and normalized importance sampling weights (3d array).
    weights::AT
    # Estimates of the shape parameter `k` of the generalized Pareto distribution.
    pareto_k::VT
    # Estimated effective sample size for each LOO evaluation.
    ess::VT
    sup_ess::VT
    r_eff::VT
    # Integer vector of tail lengths. Its container type is a type parameter (`IT`) rather
    # than the abstract `AbstractVector{Int}` so that the field is concretely typed; an
    # abstractly-typed field would be boxed and make every access type-unstable.
    tail_len::IT
    posterior_sample_size::Int
    # How many data points were used for PSIS.
    data_size::Int
end
5553
5654
# Property access for `Psis` objects. `log_weights` is not a stored field: it is computed
# on demand as the elementwise logarithm of the stored `weights`. Every other name is
# forwarded unchanged to `getfield`.
function Base.getproperty(psis_obj::Psis, k::Symbol)
    k === :log_weights && return log.(getfield(psis_obj, :weights))
    return getfield(psis_obj, k)
end
62+
63+
# List the properties of a `Psis` object: all stored fields, followed by the derived
# `log_weights` property.
function Base.propertynames(psis_object::Psis)
    stored_fields = fieldnames(typeof(psis_object))
    return (stored_fields..., :log_weights)
end
70+
71+
5772function Base. show (io:: IO , :: MIME"text/plain" , psis_object:: Psis )
5873 table = hcat (psis_object. pareto_k, psis_object. ess, psis_object. sup_ess)
5974 post_samples = psis_object. posterior_sample_size
7994"""
8095 psis(
8196 log_ratios::AbstractArray{T<:Real},
82- r_eff::AbstractVector;
97+ r_eff::AbstractVector{T} ;
8398 source::String="mcmc"
8499 ) -> Psis
85100
86101Implements Pareto-smoothed importance sampling (PSIS).
87102
88103# Arguments
104+
89105## Positional Arguments
106+
90107 - `log_ratios::AbstractArray`: A 2d or 3d array of (unnormalized) importance ratios on the
91108 log scale. Indices must be ordered as `[data, step, chain]`. The chain index can be left
92109 off if there is only one chain, or if keyword argument `chain_index` is provided.
@@ -98,15 +115,17 @@ Implements Pareto-smoothed importance sampling (PSIS).
98115 - `source::String="mcmc"`: A string or symbol describing the source of the sample being
99116 used. If `"mcmc"`, adjusts ESS for autocorrelation. Otherwise, samples are assumed to be
100117 independent. Currently permitted values are $SAMPLE_SOURCES .
118+ - `calc_ess::Bool=true`: If `false`, do not calculate ESS diagnostics. Attempting to
119+ access ESS diagnostics will then return an empty vector.
101120
102121See also: [`relative_eff`](@ref), [`psis_loo`](@ref), [`psis_ess`](@ref).
103122"""
104123function psis (
105- log_ratios:: AbstractArray{<:Real , 3} ;
106- r_eff:: AbstractVector{<:Real } = similar (log_ratios, 0 ),
124+ log_ratios:: AbstractArray{T , 3} ;
125+ r_eff:: AbstractVector{T } = similar (log_ratios, 0 ),
107126 source:: Union{AbstractString, Symbol} = " mcmc" ,
108- log_weights :: Bool = true
109- )
127+ calc_ess :: Bool = true
128+ ) where T <: Real
110129
111130 source = lowercase (String (source))
112131 dims = size (log_ratios)
@@ -115,27 +134,35 @@ function psis(
115134 post_sample_size = dims[2 ] * dims[3 ]
116135
117136 # Reshape to matrix (easier to deal with)
118- log_ratios = reshape (log_ratios, data_size, post_sample_size)
119- r_eff = _generate_r_eff (log_ratios, dims, r_eff, source)
120- _check_input_validity_psis (reshape (log_ratios, dims), r_eff)
121- weights = @. exp (log_ratios - $ maximum (log_ratios; dims= 2 ))
137+ log_ratios_mat = reshape (log_ratios, data_size, post_sample_size)
138+ r_eff = _generate_r_eff (log_ratios_mat, dims, r_eff, source)
139+ _check_input_validity_psis (log_ratios, r_eff)
140+ weights = similar (log_ratios)
141+ weights_mat = reshape (weights, data_size, post_sample_size)
142+ @. weights = exp (log_ratios - $ maximum (log_ratios; dims= (2 ,3 )))
143+
122144
123- tail_length = Vector {Int} (undef, data_size )
145+ tail_length = similar (r_eff, Int )
124146 ξ = similar (r_eff)
125147 @inbounds Threads. @threads for i in eachindex (tail_length)
126148 tail_length[i] = _def_tail_length (post_sample_size, r_eff[i])
127- ξ[i] = @views psis! (weights [i, :], tail_length[i])
149+ ξ[i] = @views psis! (weights_mat [i, :], r_eff[i]; tail_length = tail_length[i])
128150 end
129151
130- @tullio norm_const[i] := weights[i, j]
152+ @tullio norm_const[i] := weights[i, j, k ]
131153 @. weights = weights / norm_const
132- ess = psis_ess (weights, r_eff)
133- inf_ess = sup_ess (weights, r_eff)
134154
135- weights = reshape (weights, dims)
155+
156+ if calc_ess
157+ ess = psis_ess (weights_mat, r_eff)
158+ inf_ess = sup_ess (weights_mat, r_eff)
159+ else
160+ ess = similar (weights_mat, 0 )
161+ inf_ess = similar (weights_mat, 0 )
162+ end
136163
137164 return Psis (
138- weights,
165+ weights,
139166 ξ,
140167 ess,
141168 inf_ess,
@@ -193,10 +220,11 @@ log-weights.
193220Unlike the methods for arrays, `psis!` performs no checks to make sure the input values are
194221valid.
195222"""
196- function psis! (is_ratios:: AbstractVector{<:Real} , tail_length:: Integer ;
223+ function psis! (is_ratios:: AbstractVector{T} , r_eff:: T = one (T);
224+ tail_length:: Integer = _def_tail_length (length (is_ratios), r_eff),
197225 log_weights:: Bool = false
198- )
199-
226+ ) where T <: Real
227+
200228 len = length (is_ratios)
201229 tail_start = len - tail_length + 1 # index of smallest tail value
202230
@@ -213,7 +241,7 @@ function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer;
213241
214242 # Get value just before the tail starts:
215243 cutoff = is_ratios[tail_start - 1 ]
216- ξ = _psis_smooth_tail! (tail, cutoff)
244+ ξ = _psis_smooth_tail! (tail, cutoff, r_eff )
217245
218246 # truncate at max of raw weights (1 after scaling)
219247 clamp! (is_ratios, 0 , 1 )
@@ -228,38 +256,33 @@ function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer;
228256end
229257
230258
231- function psis! (is_ratios:: AbstractVector{<:Real} , r_eff:: Real = 1 )
232- tail_length = _def_tail_length (length (is_ratios), r_eff)
233- return psis! (is_ratios, tail_length)
234- end
235-
236-
"""
    _def_tail_length(length::Integer, r_eff::Real=1) -> Int

Return the default tail length for a sample of `length` draws, as in Vehtari et al. (2019):
the smaller of `cld(length, 5)` and `ceil(3 * sqrt(length / r_eff))`.

# Arguments

  - `length::Integer`: The total number of draws in the sample.
  - `r_eff::Real=1`: The value the sample size is divided by before taking the square root.
"""
function _def_tail_length(length::Integer, r_eff::Real=1)
    # The default is a literal `1`: this method has no type parameter, so a default written
    # in terms of `T` would throw an `UndefVarError` whenever `r_eff` is omitted.
    return Int(min(cld(length, 5), ceil(3 * sqrt(length / r_eff))))
end
246268
247269
"""
    _psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T, r_eff::T=one(T)) where {T<:Real}
        -> ξ::T

Takes an *already sorted* vector of observations from the tail and smooths it *in place*
with PSIS before returning shape parameter `ξ`.

If any element of `tail` is infinite, `tail` is left untouched and an infinite `ξ` (of type
`T`) is returned. Otherwise `cutoff` is subtracted from the tail, a generalized Pareto
distribution is fit with `gpd_fit(tail, r_eff)`, and element `i` of the tail is replaced by
the fitted quantile at probability `(i - 0.5) / length(tail)`, shifted back up by `cutoff`.
"""
function _psis_smooth_tail!(
    tail::AbstractVector{T}, cutoff::T, r_eff::T=one(T)
) where {T <: Real}
    # Return the infinity in the element type so the result type does not depend on which
    # branch was taken; the predicate form of `any` avoids allocating a temporary Bool array.
    any(isinf, tail) && return T(Inf)

    len = length(tail)
    @. tail = tail - cutoff

    # save time not sorting since tail is already sorted
    ξ, σ = gpd_fit(tail, r_eff)
    @. tail = gpd_quantile(($(1:len) - 0.5) / len, ξ, σ) + cutoff
    return ξ
end
0 commit comments