Skip to content

Commit 698bb94

Browse files
author
Closed-Limelike-Curves
committed
Add psis! method
1 parent ef35820 commit 698bb94

File tree

2 files changed

+50
-20
lines changed

2 files changed

+50
-20
lines changed

src/ESS.jl

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,29 @@ using Tullio
44
export relative_eff, psis_ess, sup_ess
55

66
"""
7-
relative_eff(sample::AbstractArray{<:Real, 3}; [method])
7+
relative_eff(
8+
sample::AbstractArray{Real, 3};
9+
method=MCMCDiagnosticTools.FFTESSMethod()
10+
)
811
912
Calculate the relative efficiency of an MCMC chain, i.e. the effective sample size divided
10-
by the nominal sample size. If none is provided, use the default method from
11-
MCMCDiagnosticTools.
13+
by the nominal sample size.
1214
"""
1315
function relative_eff(sample::AbstractArray{<:Real, 3}; maxlag=size(sample, 2), kwargs...)
1416
dims = size(sample)
1517
post_sample_size = dims[2] * dims[3]
1618
ess_sample = inv.(permutedims(sample, [2, 1, 3]))
17-
ess, = MCMCDiagnosticTools.ess_rhat(ess_sample; maxlag=maxlag, kwargs...)
19+
ess, = MCMCDiagnosticTools.ess_rhat(ess_sample; method=method, maxlag=dims[2])
1820
r_eff = ess / post_sample_size
1921
return r_eff
2022
end
2123

2224

2325
"""
2426
function psis_ess(
25-
weights::AbstractVector{<:Real},
26-
r_eff::AbstractVector{<:Real}
27-
) -> AbstractVector{<:Real}
27+
weights::AbstractVector{T<:Real},
28+
r_eff::AbstractVector{T}
29+
) -> AbstractVector{T}
2830
2931
Calculate the (approximate) effective sample size of a PSIS sample, using the correction in
3032
Vehtari et al. 2019. This uses the variance-based definition of ESS, and measures the L2
@@ -37,6 +39,14 @@ distance of the proposal and target distributions.
3739
3840
See `?relative_eff` to calculate `r_eff`.
3941
"""
42+
function psis_ess(
43+
weights::AbstractVector{T}, r_eff::AbstractVector{T}
44+
) where {T <: Union{Real, Missing}}
45+
@tullio sum_of_squares := weights[x]^2
46+
return r_eff ./ sum_of_squares
47+
end
48+
49+
4050
function psis_ess(
4151
weights::AbstractMatrix{T}, r_eff::AbstractVector{T}
4252
) where {T <: Union{Real, Missing}}
@@ -45,7 +55,7 @@ function psis_ess(
4555
end
4656

4757

48-
function psis_ess(weights::AbstractMatrix{<:Real})
58+
function psis_ess(weights::AbstractMatrix{<:Union{Real, Missing}})
4959
@warn "PSIS ESS not adjusted based on MCMC ESS. MCSE and ESS estimates " *
5060
"will be overoptimistic if samples are autocorrelated."
5161
return psis_ess(weights, ones(size(weights)))
@@ -54,8 +64,8 @@ end
5464

5565
"""
5666
function sup_ess(
57-
weights::AbstractVector{<:Real},
58-
r_eff::AbstractVector{<:Real}
67+
weights::AbstractVector{T},
68+
r_eff::AbstractVector{T}
5969
) -> AbstractVector
6070
6171
Calculate the supremum-based effective sample size of a PSIS sample, i.e. the inverse of the
@@ -64,8 +74,7 @@ L-∞ norm.
6474
6575
# Arguments
6676
- `weights`: A set of importance sampling weights derived from PSIS.
67-
- `r_eff`: The relative efficiency of the MCMC chains from which PSIS samples were
68-
derived.
77+
- `r_eff`: The relative efficiency of the MCMC chains; see also [`relative_eff`]@ref.
6978
"""
7079
function sup_ess(
7180
weights::AbstractMatrix{T}, r_eff::V

src/ImportanceSampling.jl

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ double check it is correct.
99
const MIN_TAIL_LEN = 5 # Minimum size of a tail for PSIS to give sensible answers
1010
const SAMPLE_SOURCES = ["mcmc", "vi", "other"]
1111

12-
export psis, PsisLoo, PsisLooMethod, Psis
12+
export psis, psis!, PsisLoo, PsisLooMethod, Psis
1313

1414

1515
###########################
@@ -125,7 +125,7 @@ function psis(
125125
ξ = similar(r_eff)
126126
@inbounds Threads.@threads for i in eachindex(tail_length)
127127
tail_length[i] = _def_tail_length(post_sample_size, r_eff[i])
128-
ξ[i] = @views ParetoSmooth._do_psis_i!(weights[i, :], tail_length[i])
128+
ξ[i] = @views psis!(weights[i, :], tail_length[i])
129129
end
130130

131131
@tullio norm_const[i] := weights[i, j]
@@ -159,21 +159,36 @@ function psis(
159159
end
160160

161161

162+
function psis(is_ratios::AbstractVector{<:Real}, args...)
163+
new_ratios = copy(is_ratios)
164+
ξ = psis!(new_ratios)
165+
return new_ratios, ξ
166+
end
167+
168+
169+
162170
"""
163-
_do_psis_i!(is_ratios::AbstractVector{Real}, tail_length::Integer) -> T
171+
psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer) -> Real
172+
psis!(is_ratios::AbstractVector{<:Real}, r_eff::Real) -> Real
164173
165-
Do PSIS on a single vector, smoothing its tail values.
174+
Do PSIS on a single vector, smoothing its tail values *in place* before returning the
175+
estimated tail value.
166176
167177
# Arguments
168178
169179
- `is_ratios::AbstractVector{<:Real}`: A vector of importance sampling ratios,
170180
scaled to have a maximum of 1.
181+
- `r_eff::AbstractVector{<:Real}`: A vector of relative effective sample sizes if .
171182
172183
# Returns
173184
174185
- `T<:Real`: ξ, the shape parameter for the GPD; big numbers indicate thick tails.
186+
187+
# Notes
188+
189+
Unlike `psis`, `psis!` performs no checks to make sure the input values are valid.
175190
"""
176-
function _do_psis_i!(is_ratios::AbstractVector{T}, tail_length::Integer) where {T <: Real}
191+
function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer)
177192

178193
len = length(is_ratios)
179194
tail_start = len - tail_length + 1 # index of smallest tail value
@@ -198,13 +213,19 @@ function _do_psis_i!(is_ratios::AbstractVector{T}, tail_length::Integer) where {
198213
end
199214

200215

216+
function psis!(is_ratios::AbstractVector{<:Real}, r_eff::Real=1)
217+
tail_length = _def_tail_length(length(is_ratios), r_eff)
218+
return psis!(is_ratios, tail_length)
219+
end
220+
221+
201222
"""
202-
_def_tail_length(log_ratios::AbstractVector, r_eff::Real) -> tail_len::Integer
223+
_def_tail_length(log_ratios::AbstractVector, r_eff::Real) -> Integer
203224
204225
Define the tail length as in Vehtari et al. (2019).
205226
"""
206-
function _def_tail_length(length::Int, r_eff::Real)
207-
return Int(ceil(min(length / 5, 3 * sqrt(length / r_eff))))
227+
function _def_tail_length(length::Integer, r_eff::Real=1)
228+
return min(cld(length, 5), ceil(3 * sqrt(length / r_eff))) |> Int
208229
end
209230

210231

0 commit comments

Comments
 (0)