Skip to content

Commit 8650285

Browse files
author
Carlos Parada
authored
add PSIS for vectors, use entropy-based ESS (#50)
1 parent 698bb94 commit 8650285

File tree

5 files changed

+58
-37
lines changed

5 files changed

+58
-37
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "0.6.6"
66
[deps]
77
AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5"
88
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
9+
Lazy = "50d2b5c4-7a5e-59d5-8109-a42b560f39c0"
910
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1011
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
1112
MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"

docs/src/index.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,9 @@ Documentation for [ParetoSmooth](https://github.com/TuringLang/ParetoSmooth.jl).
1111

1212
```@autodocs
1313
Modules = [ParetoSmooth]
14+
Private = false
15+
```
16+
17+
```@docs
18+
naive_lpd
1419
```

src/ESS.jl

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,16 @@ export relative_eff, psis_ess, sup_ess
1111
1212
Calculate the relative efficiency of an MCMC chain, i.e. the effective sample size divided
1313
by the nominal sample size.
14+
15+
# Arguments
16+
17+
- `sample::AbstractArray{<:Real, 3}`: An array of log-likelihood values.
1418
"""
1519
function relative_eff(sample::AbstractArray{<:Real, 3}; maxlag=size(sample, 2), kwargs...)
1620
dims = size(sample)
1721
post_sample_size = dims[2] * dims[3]
18-
ess_sample = inv.(permutedims(sample, [2, 1, 3]))
19-
ess, = MCMCDiagnosticTools.ess_rhat(ess_sample; method=method, maxlag=dims[2])
22+
ess_sample = permutedims(sample, [2, 1, 3])
23+
ess, = MCMCDiagnosticTools.ess_rhat(ess_sample; maxlag=dims[2], kwargs...)
2024
r_eff = ess / post_sample_size
2125
return r_eff
2226
end
@@ -34,28 +38,20 @@ distance of the proposal and target distributions.
3438
3539
# Arguments
3640
37-
- `weights`: A set of importance sampling weights derived from PSIS.
41+
- `weights`: A set of normalized importance sampling weights derived from PSIS.
3842
- `r_eff`: The relative efficiency of the MCMC chains from which PSIS samples were derived.
3943
4044
See `?relative_eff` to calculate `r_eff`.
4145
"""
42-
function psis_ess(
43-
weights::AbstractVector{T}, r_eff::AbstractVector{T}
44-
) where {T <: Union{Real, Missing}}
45-
@tullio sum_of_squares := weights[x]^2
46-
return r_eff ./ sum_of_squares
47-
end
48-
49-
5046
function psis_ess(
5147
weights::AbstractMatrix{T}, r_eff::AbstractVector{T}
52-
) where {T <: Union{Real, Missing}}
53-
@tullio sum_of_squares[x] := weights[x, y]^2
48+
) where {T <: Real}
49+
@tullio sum_of_squares[x] := xlogx(weights[x, y]) |> exp
5450
return r_eff ./ sum_of_squares
5551
end
5652

5753

58-
function psis_ess(weights::AbstractMatrix{<:Union{Real, Missing}})
54+
function psis_ess(weights::AbstractMatrix{<:Real})
5955
@warn "PSIS ESS not adjusted based on MCMC ESS. MCSE and ESS estimates " *
6056
"will be overoptimistic if samples are autocorrelated."
6157
return psis_ess(weights, ones(size(weights)))
@@ -77,7 +73,7 @@ L-∞ norm.
7773
- `r_eff`: The relative efficiency of the MCMC chains; see also [`relative_eff`]@ref.
7874
"""
7975
function sup_ess(
80-
weights::AbstractMatrix{T}, r_eff::V
81-
) where {T<:Real, V<:AbstractVector{T}}
76+
weights::AbstractMatrix{T}, r_eff::AbstractVector{T}
77+
) where {T<:Real}
8278
return inv.(dropdims(maximum(weights; dims=2); dims=2)) .* r_eff
8379
end

src/ImportanceSampling.jl

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ A struct containing the results of Pareto-smoothed importance sampling.
2424
2525
# Fields
2626
27-
- `weights`: A vector of smoothed, truncated, and normalized importance sampling weights.
27+
- `log_weights`: A vector of smoothed and truncated but *unnormalized* importance sampling
28+
weights.
29+
- `weights`: A lazy
2830
- `pareto_k`: Estimates of the shape parameter `k` of the generalized Pareto distribution.
2931
- `ess`: Estimated effective sample size for each LOO evaluation, based on the variance of
3032
the weights.
@@ -102,7 +104,8 @@ See also: [`relative_eff`]@ref, [`psis_loo`]@ref, [`psis_ess`]@ref.
102104
function psis(
103105
log_ratios::AbstractArray{<:Real, 3};
104106
r_eff::AbstractVector{<:Real}=similar(log_ratios, 0),
105-
source::Union{AbstractString, Symbol}="mcmc"
107+
source::Union{AbstractString, Symbol}="mcmc",
108+
log_weights::Bool=true
106109
)
107110

108111
source = lowercase(String(source))
@@ -114,12 +117,8 @@ function psis(
114117
# Reshape to matrix (easier to deal with)
115118
log_ratios = reshape(log_ratios, data_size, post_sample_size)
116119
r_eff = _generate_r_eff(log_ratios, dims, r_eff, source)
117-
weights = similar(log_ratios)
118-
# Shift ratios by maximum to prevent overflow
119-
@. weights = exp(log_ratios - $maximum(log_ratios; dims=2))
120-
121-
r_eff = _generate_r_eff(weights, dims, r_eff, source)
122120
_check_input_validity_psis(reshape(log_ratios, dims), r_eff)
121+
weights = @. exp(log_ratios - $maximum(log_ratios; dims=2))
123122

124123
tail_length = Vector{Int}(undef, data_size)
125124
ξ = similar(r_eff)
@@ -159,36 +158,44 @@ function psis(
159158
end
160159

161160

162-
function psis(is_ratios::AbstractVector{<:Real}, args...)
161+
function psis(is_ratios::AbstractVector{<:Real}, args...; kwargs...)
163162
new_ratios = copy(is_ratios)
164-
ξ = psis!(new_ratios)
163+
ξ = psis!(new_ratios, kwargs...)
165164
return new_ratios, ξ
166165
end
167166

168167

169168

170169
"""
171-
psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer) -> Real
172-
psis!(is_ratios::AbstractVector{<:Real}, r_eff::Real) -> Real
170+
psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer; log_ratios=false) -> Real
171+
psis!(is_ratios::AbstractVector{<:Real}, r_eff::Real; log_ratios=false) -> Real
173172
174173
Do PSIS on a single vector, smoothing its tail values *in place* before returning the
175-
estimated tail value.
174+
estimated shape constant for the `pareto_k` distribution. This *does not* normalize the
175+
log-weights.
176176
177177
# Arguments
178178
179179
- `is_ratios::AbstractVector{<:Real}`: A vector of importance sampling ratios,
180180
scaled to have a maximum of 1.
181-
- `r_eff::AbstractVector{<:Real}`: A vector of relative effective sample sizes if .
181+
- `r_eff::AbstractVector{<:Real}`: The relative effective sample size, used for improving
182+
the .
183+
case `psis!` will automatically calculate the correct tail length.
184+
- `log_weights::Bool`: A boolean indicating whether the input vector is a vector of log
185+
ratios, rather than raw importance sampling ratios.
182186
183187
# Returns
184188
185-
- `T<:Real`: ξ, the shape parameter for the GPD; big numbers indicate thick tails.
189+
- `Real`: ξ, the shape parameter for the GPD. Bigger numbers indicate thicker tails.
186190
187191
# Notes
188192
189-
Unlike `psis`, `psis!` performs no checks to make sure the input values are valid.
193+
Unlike the methods for arrays, `psis!` performs no checks to make sure the input values are
194+
valid.
190195
"""
191-
function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer)
196+
function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer;
197+
log_weights::Bool=false
198+
)
192199

193200
len = length(is_ratios)
194201
tail_start = len - tail_length + 1 # index of smallest tail value
@@ -199,6 +206,10 @@ function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer)
199206
is_ratios .= first.(ratio_index)
200207
@views tail = is_ratios[tail_start:len]
201208
_check_tail(tail)
209+
if log_weights
210+
biggest = maximum(tail)
211+
@. tail = exp(tail - biggest)
212+
end
202213

203214
# Get value just before the tail starts:
204215
cutoff = is_ratios[tail_start - 1]
@@ -209,7 +220,11 @@ function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer)
209220
# unsort the ratios to their original position:
210221
invpermute!(is_ratios, last.(ratio_index))
211222

212-
return ξ::T
223+
if log_weights
224+
@. tail = log(tail + biggest)
225+
end
226+
227+
return ξ
213228
end
214229

215230

@@ -222,7 +237,8 @@ end
222237
"""
223238
_def_tail_length(log_ratios::AbstractVector, r_eff::Real) -> Integer
224239
225-
Define the tail length as in Vehtari et al. (2019).
240+
Define the tail length as in Vehtari et al. (2019), with the small addition that the tail
241+
must a multiple of `32*bit_length` (which improves performance).
226242
"""
227243
function _def_tail_length(length::Integer, r_eff::Real=1)
228244
return min(cld(length, 5), ceil(3 * sqrt(length / r_eff))) |> Int
@@ -322,7 +338,7 @@ end
322338
Check the tail to make sure a GPD fit is possible.
323339
"""
324340
function _check_tail(tail::AbstractVector{T}) where {T <: Real}
325-
if maximum(tail) minimum(tail)
341+
if tail[end] tail[1]
326342
throw(
327343
ArgumentError(
328344
"Unable to fit generalized Pareto distribution: all tail values are the " *
@@ -333,7 +349,7 @@ function _check_tail(tail::AbstractVector{T}) where {T <: Real}
333349
throw(
334350
ArgumentError(
335351
"Unable to fit generalized Pareto distribution: tail length was too " *
336-
"short. Likely causese are: \n$LIKELY_ERROR_CAUSES",
352+
"short. Likely causes are: \n$LIKELY_ERROR_CAUSES",
337353
),
338354
)
339355
end

src/NaiveLPD.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
naive_lpd(log_likelihood::AbstractArray{<:Real}[, chain_index])
33
44
Calculate the naive (in-sample) estimate of the expected log probability density, otherwise
5-
known as the in-sample Bayes score. Not recommended for most uses.
5+
known as the in-sample Bayes score. This method yields heavily biased results, and we advise
6+
against using it; it is included only for pedagogical purposes.
7+
8+
This method is unexported and can only be accessed by calling `ParetoSmooth.naive_lpd`.
69
710
# Arguments
811
- $LOG_LIK_ARR

0 commit comments

Comments
 (0)