From 4f36ac0912a56c47add036febdf84dac7b2a85a8 Mon Sep 17 00:00:00 2001 From: Closed-Limelike-Curves Date: Sun, 26 Sep 2021 14:39:50 -0500 Subject: [PATCH 01/14] add PSIS for vectors, use entropy-based ESS --- Project.toml | 1 + docs/src/index.md | 5 ++++ src/ESS.jl | 28 +++++++++----------- src/ImportanceSampling.jl | 56 +++++++++++++++++++++++++-------------- src/NaiveLPD.jl | 5 +++- 5 files changed, 58 insertions(+), 37 deletions(-) diff --git a/Project.toml b/Project.toml index 8bd9cf3..12320cd 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.6.6" [deps] AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +Lazy = "50d2b5c4-7a5e-59d5-8109-a42b560f39c0" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0" diff --git a/docs/src/index.md b/docs/src/index.md index b70e04c..dc36fda 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -11,4 +11,9 @@ Documentation for [ParetoSmooth](https://github.com/TuringLang/ParetoSmooth.jl). ```@autodocs Modules = [ParetoSmooth] +Private = false +``` + +```@docs +naive_lpd ``` diff --git a/src/ESS.jl b/src/ESS.jl index 8f1bfd0..9d674fe 100644 --- a/src/ESS.jl +++ b/src/ESS.jl @@ -11,12 +11,16 @@ export relative_eff, psis_ess, sup_ess Calculate the relative efficiency of an MCMC chain, i.e. the effective sample size divided by the nominal sample size. + +# Arguments + + - `sample::AbstractArray{<:Real, 3}`: An array of log-likelihood values. """ function relative_eff(sample::AbstractArray{<:Real, 3}; maxlag=size(sample, 2), kwargs...) dims = size(sample) post_sample_size = dims[2] * dims[3] - ess_sample = inv.(permutedims(sample, [2, 1, 3])) - ess, = MCMCDiagnosticTools.ess_rhat(ess_sample; method=method, maxlag=dims[2]) + ess_sample = permutedims(sample, [2, 1, 3]) + ess, = MCMCDiagnosticTools.ess_rhat(ess_sample; maxlag=dims[2], kwargs...) r_eff = ess / post_sample_size return r_eff end @@ -34,28 +38,20 @@ distance of the proposal and target distributions. # Arguments - - `weights`: A set of importance sampling weights derived from PSIS. + - `weights`: A set of normalized importance sampling weights derived from PSIS. - `r_eff`: The relative efficiency of the MCMC chains from which PSIS samples were derived. See `?relative_eff` to calculate `r_eff`. """ -function psis_ess( - weights::AbstractVector{T}, r_eff::AbstractVector{T} -) where {T <: Union{Real, Missing}} - @tullio sum_of_squares := weights[x]^2 - return r_eff ./ sum_of_squares -end - - function psis_ess( weights::AbstractMatrix{T}, r_eff::AbstractVector{T} -) where {T <: Union{Real, Missing}} - @tullio sum_of_squares[x] := weights[x, y]^2 +) where {T <: Real} + @tullio sum_of_squares[x] := xlogx(weights[x, y]) |> exp return r_eff ./ sum_of_squares end -function psis_ess(weights::AbstractMatrix{<:Union{Real, Missing}}) +function psis_ess(weights::AbstractMatrix{<:Real}) @warn "PSIS ESS not adjusted based on MCMC ESS. MCSE and ESS estimates " * "will be overoptimistic if samples are autocorrelated." return psis_ess(weights, ones(size(weights))) @@ -77,7 +73,7 @@ L-∞ norm. - `r_eff`: The relative efficiency of the MCMC chains; see also [`relative_eff`]@ref. """ function sup_ess( - weights::AbstractMatrix{T}, r_eff::V -) where {T<:Real, V<:AbstractVector{T}} + weights::AbstractMatrix{T}, r_eff::AbstractVector{T} +) where {T<:Real} return inv.(dropdims(maximum(weights; dims=2); dims=2)) .* r_eff end diff --git a/src/ImportanceSampling.jl b/src/ImportanceSampling.jl index d1a4e38..9c8de24 100644 --- a/src/ImportanceSampling.jl +++ b/src/ImportanceSampling.jl @@ -24,7 +24,9 @@ A struct containing the results of Pareto-smoothed importance sampling. # Fields - - `weights`: A vector of smoothed, truncated, and normalized importance sampling weights. + - `log_weights`: A vector of smoothed and truncated but *unnormalized* importance sampling + weights. + - `weights`: A lazy - `pareto_k`: Estimates of the shape parameter `k` of the generalized Pareto distribution. - `ess`: Estimated effective sample size for each LOO evaluation, based on the variance of the weights. @@ -102,7 +104,8 @@ See also: [`relative_eff`]@ref, [`psis_loo`]@ref, [`psis_ess`]@ref. function psis( log_ratios::AbstractArray{<:Real, 3}; r_eff::AbstractVector{<:Real}=similar(log_ratios, 0), - source::Union{AbstractString, Symbol}="mcmc" + source::Union{AbstractString, Symbol}="mcmc", + log_weights::Bool=true ) source = lowercase(String(source)) @@ -114,12 +117,8 @@ function psis( # Reshape to matrix (easier to deal with) log_ratios = reshape(log_ratios, data_size, post_sample_size) r_eff = _generate_r_eff(log_ratios, dims, r_eff, source) - weights = similar(log_ratios) - # Shift ratios by maximum to prevent overflow - @. weights = exp(log_ratios - $maximum(log_ratios; dims=2)) - - r_eff = _generate_r_eff(weights, dims, r_eff, source) _check_input_validity_psis(reshape(log_ratios, dims), r_eff) + weights = @. exp(log_ratios - $maximum(log_ratios; dims=2)) tail_length = Vector{Int}(undef, data_size) ξ = similar(r_eff) @@ -159,36 +158,44 @@ function psis( end -function psis(is_ratios::AbstractVector{<:Real}, args...) +function psis(is_ratios::AbstractVector{<:Real}, args...; kwargs...) new_ratios = copy(is_ratios) - ξ = psis!(new_ratios) + ξ = psis!(new_ratios, kwargs...) return new_ratios, ξ end """ - psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer) -> Real - psis!(is_ratios::AbstractVector{<:Real}, r_eff::Real) -> Real + psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer; log_ratios=false) -> Real + psis!(is_ratios::AbstractVector{<:Real}, r_eff::Real; log_ratios=false) -> Real Do PSIS on a single vector, smoothing its tail values *in place* before returning the -estimated tail value. +estimated shape constant for the `pareto_k` distribution. This *does not* normalize the +log-weights. # Arguments - `is_ratios::AbstractVector{<:Real}`: A vector of importance sampling ratios, scaled to have a maximum of 1. - - `r_eff::AbstractVector{<:Real}`: A vector of relative effective sample sizes if . + - `r_eff::AbstractVector{<:Real}`: The relative effective sample size, used for improving + the . + case `psis!` will automatically calculate the correct tail length. + - `log_weights::Bool`: A boolean indicating whether the input vector is a vector of log + ratios, rather than raw importance sampling ratios. # Returns - - `T<:Real`: ξ, the shape parameter for the GPD; big numbers indicate thick tails. + - `Real`: ξ, the shape parameter for the GPD. Bigger numbers indicate thicker tails. # Notes -Unlike `psis`, `psis!` performs no checks to make sure the input values are valid. +Unlike the methods for arrays, `psis!` performs no checks to make sure the input values are +valid. """ -function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer) +function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer; + log_weights::Bool=false +) len = length(is_ratios) tail_start = len - tail_length + 1 # index of smallest tail value @@ -199,6 +206,10 @@ function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer) is_ratios .= first.(ratio_index) @views tail = is_ratios[tail_start:len] _check_tail(tail) + if log_weights + biggest = maximum(tail) + @. tail = exp(tail - biggest) + end # Get value just before the tail starts: cutoff = is_ratios[tail_start - 1] @@ -209,7 +220,11 @@ function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer) # unsort the ratios to their original position: invpermute!(is_ratios, last.(ratio_index)) - return ξ::T + if log_weights + @. tail = log(tail + biggest) + end + + return ξ end @@ -222,7 +237,8 @@ end """ _def_tail_length(log_ratios::AbstractVector, r_eff::Real) -> Integer -Define the tail length as in Vehtari et al. (2019). +Define the tail length as in Vehtari et al. (2019), with the small addition that the tail +must a multiple of `32*bit_length` (which improves performance). """ function _def_tail_length(length::Integer, r_eff::Real=1) return min(cld(length, 5), ceil(3 * sqrt(length / r_eff))) |> Int @@ -322,7 +338,7 @@ end Check the tail to make sure a GPD fit is possible. """ function _check_tail(tail::AbstractVector{T}) where {T <: Real} - if maximum(tail) ≈ minimum(tail) + if tail[end] ≈ tail[1] throw( ArgumentError( "Unable to fit generalized Pareto distribution: all tail values are the " * @@ -333,7 +349,7 @@ function _check_tail(tail::AbstractVector{T}) where {T <: Real} throw( ArgumentError( "Unable to fit generalized Pareto distribution: tail length was too " * - "short. Likely causese are: \n$LIKELY_ERROR_CAUSES", + "short. Likely causes are: \n$LIKELY_ERROR_CAUSES", ), ) end diff --git a/src/NaiveLPD.jl b/src/NaiveLPD.jl index ab50208..3bb0a75 100644 --- a/src/NaiveLPD.jl +++ b/src/NaiveLPD.jl @@ -2,7 +2,10 @@ naive_lpd(log_likelihood::AbstractArray{<:Real}[, chain_index]) Calculate the naive (in-sample) estimate of the expected log probability density, otherwise -known as the in-sample Bayes score. Not recommended for most uses. +known as the in-sample Bayes score. This method yields heavily biased results, and we advise +against using it; it is included only for pedagogical purposes. + +This method is unexported and can only be accessed by calling `ParetoSmooth.naive_lpd`. # Arguments - $LOG_LIK_ARR From 919274912d2e563ec04fbf37523c998cfba7bae4 Mon Sep 17 00:00:00 2001 From: Closed-Limelike-Curves Date: Sun, 26 Sep 2021 19:33:42 -0500 Subject: [PATCH 02/14] Add option to avoid calculating ESS --- src/ImportanceSampling.jl | 47 +++++++++++++++++++++++++-------------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/src/ImportanceSampling.jl b/src/ImportanceSampling.jl index 9c8de24..fb76dde 100644 --- a/src/ImportanceSampling.jl +++ b/src/ImportanceSampling.jl @@ -39,16 +39,16 @@ A struct containing the results of Pareto-smoothed importance sampling. - `data_size`: How many data points were used for PSIS. """ struct Psis{ - RealType <: Real, - AT <: AbstractArray{RealType, 3}, - VT <: AbstractVector{RealType}, + R <: Real, + AT <: AbstractArray{R, 3}, + VT <: AbstractVector{R} } weights::AT pareto_k::VT ess::VT sup_ess::VT r_eff::VT - tail_len::Vector{Int} + tail_len::AbstractVector{Int} posterior_sample_size::Int data_size::Int end @@ -86,7 +86,9 @@ end Implements Pareto-smoothed importance sampling (PSIS). # Arguments + ## Positional Arguments + - `log_ratios::AbstractArray`: A 2d or 3d array of (unnormalized) importance ratios on the log scale. Indices must be ordered as `[data, step, chain]`. The chain index can be left off if there is only one chain, or if keyword argument `chain_index` is provided. @@ -98,6 +100,8 @@ Implements Pareto-smoothed importance sampling (PSIS). - `source::String="mcmc"`: A string or symbol describing the source of the sample being used. If `"mcmc"`, adjusts ESS for autocorrelation. Otherwise, samples are assumed to be independent. Currently permitted values are $SAMPLE_SOURCES. + - `log_weights::Bool`: If `true` + - `calc_ess::Bool = true` See also: [`relative_eff`]@ref, [`psis_loo`]@ref, [`psis_ess`]@ref. """ @@ -105,7 +109,7 @@ function psis( log_ratios::AbstractArray{<:Real, 3}; r_eff::AbstractVector{<:Real}=similar(log_ratios, 0), source::Union{AbstractString, Symbol}="mcmc", - log_weights::Bool=true + calc_ess::Bool = true ) source = lowercase(String(source)) @@ -115,27 +119,36 @@ function psis( post_sample_size = dims[2] * dims[3] # Reshape to matrix (easier to deal with) - log_ratios = reshape(log_ratios, data_size, post_sample_size) - r_eff = _generate_r_eff(log_ratios, dims, r_eff, source) - _check_input_validity_psis(reshape(log_ratios, dims), r_eff) - weights = @. exp(log_ratios - $maximum(log_ratios; dims=2)) + log_ratios_mat = reshape(log_ratios, data_size, post_sample_size) + r_eff = _generate_r_eff(log_ratios_mat, dims, r_eff, source) + _check_input_validity_psis(log_ratios, r_eff) + weights = similar(log_ratios) + weights_mat = reshape(weights, data_size, post_sample_size) + @. weights = exp(log_ratios - $maximum(log_ratios; dims=(2,3))) + - tail_length = Vector{Int}(undef, data_size) + tail_length = similar(log_ratios, data_size) ξ = similar(r_eff) @inbounds Threads.@threads for i in eachindex(tail_length) tail_length[i] = _def_tail_length(post_sample_size, r_eff[i]) - ξ[i] = @views psis!(weights[i, :], tail_length[i]) + ξ[i] = @views psis!(weights_mat[i, :], tail_length[i]) end - @tullio norm_const[i] := weights[i, j] + @tullio norm_const[i] := weights[i, j, k] @. weights = weights / norm_const - ess = psis_ess(weights, r_eff) - inf_ess = sup_ess(weights, r_eff) - weights = reshape(weights, dims) + ess = similar(weights, data_size) + psis_ess = similar(weights, data_size) + if calc_ess + ess .= psis_ess(weights, r_eff) + inf_ess .= sup_ess(weights, r_eff) + else + ess .= NaN + inf_ess .= NaN + end return Psis( - weights, + weights, ξ, ess, inf_ess, @@ -196,7 +209,7 @@ valid. function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer; log_weights::Bool=false ) - + len = length(is_ratios) tail_start = len - tail_length + 1 # index of smallest tail value From b83f1d6bccb4d4a5af392fe659b55d8743d6633b Mon Sep 17 00:00:00 2001 From: Closed-Limelike-Curves Date: Sun, 26 Sep 2021 19:38:43 -0500 Subject: [PATCH 03/14] avoid vector allocation --- src/ImportanceSampling.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/ImportanceSampling.jl b/src/ImportanceSampling.jl index fb76dde..827f5ae 100644 --- a/src/ImportanceSampling.jl +++ b/src/ImportanceSampling.jl @@ -137,12 +137,15 @@ function psis( @tullio norm_const[i] := weights[i, j, k] @. weights = weights / norm_const - ess = similar(weights, data_size) - psis_ess = similar(weights, data_size) + if calc_ess + ess = similar(weights, data_size) + psis_ess = similar(weights, data_size) ess .= psis_ess(weights, r_eff) inf_ess .= sup_ess(weights, r_eff) else + ess = similar(weights, 1) + psis_ess = similar(weights, 1) ess .= NaN inf_ess .= NaN end From 15eee59e2a2a6e2ebbc9e490232ef2bc78e3a45a Mon Sep 17 00:00:00 2001 From: Closed-Limelike-Curves Date: Sun, 26 Sep 2021 19:51:35 -0500 Subject: [PATCH 04/14] simplify --- src/ImportanceSampling.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/ImportanceSampling.jl b/src/ImportanceSampling.jl index 827f5ae..1de639a 100644 --- a/src/ImportanceSampling.jl +++ b/src/ImportanceSampling.jl @@ -139,10 +139,8 @@ function psis( if calc_ess - ess = similar(weights, data_size) - psis_ess = similar(weights, data_size) - ess .= psis_ess(weights, r_eff) - inf_ess .= sup_ess(weights, r_eff) + ess = psis_ess(weights, r_eff) + inf_ess = sup_ess(weights, r_eff) else ess = similar(weights, 1) psis_ess = similar(weights, 1) From 4e236d294bf459b9652dd6fba864d3080fe8cf79 Mon Sep 17 00:00:00 2001 From: Closed-Limelike-Curves Date: Sun, 26 Sep 2021 19:55:12 -0500 Subject: [PATCH 05/14] drop failing "nightly" support --- .github/workflows/CI.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index f79cec3..5f9114c 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -17,7 +17,6 @@ jobs: version: - '1' # latest stable 1.x release of Julia - '1.6' # oldest supported version - - 'nightly' os: - ubuntu-latest arch: From 0674636ca22000ef6820fff643bf9716be4b078c Mon Sep 17 00:00:00 2001 From: Closed-Limelike-Curves Date: Sun, 26 Sep 2021 20:04:16 -0500 Subject: [PATCH 06/14] Minor doc changes --- src/AbstractCV.jl | 23 +++++++++++------------ src/ImportanceSampling.jl | 2 +- src/LeaveOneOut.jl | 24 ++++++++++++------------ 3 files changed, 24 insertions(+), 25 deletions(-) diff --git a/src/AbstractCV.jl b/src/AbstractCV.jl index 6dcb4de..54bc2c6 100644 --- a/src/AbstractCV.jl +++ b/src/AbstractCV.jl @@ -2,9 +2,8 @@ using AxisKeys using PrettyTables -export AbstractCVMethod, AbstractCV +# export AbstractCVMethod, AbstractCV -const POINTWISE_LABELS = (:cv_elpd, :naive_lpd, :p_eff, :ess, :pareto_k) const CV_DESC = """ # Fields @@ -65,20 +64,20 @@ comparison using cross-validation methods. #### CROSS VALIDATION #### ########################## -""" - AbstractCV +# """ +# AbstractCV -An abstract type used in cross-validation. -""" -abstract type AbstractCV end +# An abstract type used in cross-validation. +# """ +# abstract type AbstractCV end -""" - AbstractCVMethod +# """ +# AbstractCVMethod -An abstract type used to dispatch the correct method for cross validation. -""" -abstract type AbstractCVMethod end +# An abstract type used to dispatch the correct method for cross validation. +# """ +# abstract type AbstractCVMethod end ########################## diff --git a/src/ImportanceSampling.jl b/src/ImportanceSampling.jl index 1de639a..f77d41b 100644 --- a/src/ImportanceSampling.jl +++ b/src/ImportanceSampling.jl @@ -9,7 +9,7 @@ double check it is correct. const MIN_TAIL_LEN = 5 # Minimum size of a tail for PSIS to give sensible answers const SAMPLE_SOURCES = ["mcmc", "vi", "other"] -export psis, psis!, PsisLoo, PsisLooMethod, Psis +export psis, psis!, Psis ########################### diff --git a/src/LeaveOneOut.jl b/src/LeaveOneOut.jl index 584b806..ebb995b 100644 --- a/src/LeaveOneOut.jl +++ b/src/LeaveOneOut.jl @@ -5,7 +5,7 @@ using Statistics using Printf using Tullio -export loo, psis_loo, loo_from_psis +export loo, psis_loo, loo_from_psis, PsisLoo ##################### @@ -13,19 +13,19 @@ export loo, psis_loo, loo_from_psis ##################### -""" - PsisLooMethod +# """ +# PsisLooMethod -Use Pareto-smoothed importance sampling together with leave-one-out cross validation to -estimate the out-of-sample predictive accuracy. -""" -struct PsisLooMethod <: AbstractCVMethod end +# Use Pareto-smoothed importance sampling together with leave-one-out cross validation to +# estimate the out-of-sample predictive accuracy. +# """ +# struct PsisLooMethod <: AbstractCVMethod end """ PsisLoo <: AbstractCV -A struct containing the results of leave-one-out cross validation using Pareto +A struct containing the results of leave-one-out cross validation computed with Pareto smoothed importance sampling. $CV_DESC @@ -71,17 +71,17 @@ end """ - function loo(args...; method=PsisLooMethod(), kwargs...) -> PsisLoo + function loo(args...; kwargs...) -> PsisLoo -Compute the approximate leave-one-out cross-validation score using the specified method. +Compute an approximate leave-one-out cross-validation score. Currently, this function only serves to call `psis_loo`, but this could change in the -future. The default methods or return type may change without warning; thus, we recommend +future. The default methods or return type may change without warning, so we recommend using `psis_loo` instead if reproducibility is required. See also: [`psis_loo`](@ref), [`PsisLoo`](@ref). """ -function loo(args...; method=PsisLooMethod(), kwargs...) +@nospecialize function loo(args...; kwargs...) return psis_loo(args...; kwargs...) end From 85af68089b3a1279d2ca02cc7d7592fe40ddfc81 Mon Sep 17 00:00:00 2001 From: Closed-Limelike-Curves Date: Sun, 26 Sep 2021 20:12:57 -0500 Subject: [PATCH 07/14] whoops fix --- src/AbstractCV.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/AbstractCV.jl b/src/AbstractCV.jl index 54bc2c6..3fc2bcc 100644 --- a/src/AbstractCV.jl +++ b/src/AbstractCV.jl @@ -64,12 +64,12 @@ comparison using cross-validation methods. #### CROSS VALIDATION #### ########################## -# """ -# AbstractCV +""" + AbstractCV -# An abstract type used in cross-validation. -# """ -# abstract type AbstractCV end +An abstract type used in cross-validation. +""" +abstract type AbstractCV end # """ From 0375f0cf5bc5a7efa82141929769d05365073204 Mon Sep 17 00:00:00 2001 From: Closed-Limelike-Curves Date: Sun, 26 Sep 2021 20:16:25 -0500 Subject: [PATCH 08/14] drop `nospecialize` (breaks Documenter) --- src/LeaveOneOut.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/LeaveOneOut.jl b/src/LeaveOneOut.jl index ebb995b..8fac50b 100644 --- a/src/LeaveOneOut.jl +++ b/src/LeaveOneOut.jl @@ -81,7 +81,7 @@ using `psis_loo` instead if reproducibility is required. See also: [`psis_loo`](@ref), [`PsisLoo`](@ref). """ -@nospecialize function loo(args...; kwargs...) +function loo(args...; kwargs...) return psis_loo(args...; kwargs...) end From 13bf178d97381b7a233a35917b00440702d8dfa4 Mon Sep 17 00:00:00 2001 From: Closed-Limelike-Curves Date: Sun, 26 Sep 2021 20:21:55 -0500 Subject: [PATCH 09/14] Bug and docs --- src/ESS.jl | 4 ++-- src/ImportanceSampling.jl | 2 +- src/ModelComparison.jl | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ESS.jl b/src/ESS.jl index 9d674fe..8fde445 100644 --- a/src/ESS.jl +++ b/src/ESS.jl @@ -65,8 +65,8 @@ end ) -> AbstractVector Calculate the supremum-based effective sample size of a PSIS sample, i.e. the inverse of the -maximum weight. This measure is more trustworthy than the `ess` from `psis_ess`. It uses the -L-∞ norm. +maximum weight. This measure is more sensitive than the `ess` from `psis_ess`, but also +much more variable. It uses the L-∞ norm. # Arguments - `weights`: A set of importance sampling weights derived from PSIS. diff --git a/src/ImportanceSampling.jl b/src/ImportanceSampling.jl index f77d41b..c01dd18 100644 --- a/src/ImportanceSampling.jl +++ b/src/ImportanceSampling.jl @@ -143,7 +143,7 @@ function psis( inf_ess = sup_ess(weights, r_eff) else ess = similar(weights, 1) - psis_ess = similar(weights, 1) + inf_ess = similar(weights, 1) ess .= NaN inf_ess .= NaN end diff --git a/src/ModelComparison.jl b/src/ModelComparison.jl index 855f3d8..79c2f88 100644 --- a/src/ModelComparison.jl +++ b/src/ModelComparison.jl @@ -10,7 +10,7 @@ A struct containing the results of model comparison. # Fields - - `pointwise::KeyedArray`: An array containing +- `pointwise::KeyedArray`: A `KeyedArray` of pointwise estimates. See [`PsisLoo`]@ref. - `estimates::KeyedArray`: A table containing the results of model comparison, with the following columns -- + `cv_elpd`: The difference in total leave-one-out cross validation scores From a27df128f251428bfc70d283202b7c46ecfc2227 Mon Sep 17 00:00:00 2001 From: Closed-Limelike-Curves Date: Sun, 26 Sep 2021 20:23:59 -0500 Subject: [PATCH 10/14] Minor bug --- src/ESS.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ESS.jl b/src/ESS.jl index 8fde445..186aae5 100644 --- a/src/ESS.jl +++ b/src/ESS.jl @@ -44,9 +44,9 @@ distance of the proposal and target distributions. See `?relative_eff` to calculate `r_eff`. """ function psis_ess( - weights::AbstractMatrix{T}, r_eff::AbstractVector{T} + weights::AbstractArray{T,3}, r_eff::AbstractVector{T} ) where {T <: Real} - @tullio sum_of_squares[x] := xlogx(weights[x, y]) |> exp + @tullio sum_of_squares[x] := xlogx(weights[x, y, z]) |> exp return r_eff ./ sum_of_squares end From 9219ebc88f3eba9a99ff9c6a741cdd9399704100 Mon Sep 17 00:00:00 2001 From: Closed-Limelike-Curves Date: Sun, 26 Sep 2021 20:39:37 -0500 Subject: [PATCH 11/14] fix type bug --- src/ESS.jl | 8 ++++---- src/ImportanceSampling.jl | 10 +++++----- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/ESS.jl b/src/ESS.jl index 186aae5..6781836 100644 --- a/src/ESS.jl +++ b/src/ESS.jl @@ -5,7 +5,7 @@ export relative_eff, psis_ess, sup_ess """ relative_eff( - sample::AbstractArray{Real, 3}; + sample::AbstractArray{<:Real, 3}; method=MCMCDiagnosticTools.FFTESSMethod() ) @@ -16,7 +16,7 @@ by the nominal sample size. - `sample::AbstractArray{<:Real, 3}`: An array of log-likelihood values. """ -function relative_eff(sample::AbstractArray{<:Real, 3}; maxlag=size(sample, 2), kwargs...) +function relative_eff(sample::AbstractArray{<:Real,3}; maxlag=size(sample, 2), kwargs...) dims = size(sample) post_sample_size = dims[2] * dims[3] ess_sample = permutedims(sample, [2, 1, 3]) @@ -44,7 +44,7 @@ distance of the proposal and target distributions. See `?relative_eff` to calculate `r_eff`. """ function psis_ess( - weights::AbstractArray{T,3}, r_eff::AbstractVector{T} + weights::AbstractMatrix{T}, r_eff::AbstractVector{T} ) where {T <: Real} @tullio sum_of_squares[x] := xlogx(weights[x, y, z]) |> exp return r_eff ./ sum_of_squares @@ -60,7 +60,7 @@ end """ function sup_ess( - weights::AbstractVector{T}, + weights::AbstractMatrix{T}, r_eff::AbstractVector{T} ) -> AbstractVector diff --git a/src/ImportanceSampling.jl b/src/ImportanceSampling.jl index c01dd18..6e7e547 100644 --- a/src/ImportanceSampling.jl +++ b/src/ImportanceSampling.jl @@ -127,7 +127,7 @@ function psis( @. weights = exp(log_ratios - $maximum(log_ratios; dims=(2,3))) - tail_length = similar(log_ratios, data_size) + tail_length = similar(r_eff, Int) ξ = similar(r_eff) @inbounds Threads.@threads for i in eachindex(tail_length) tail_length[i] = _def_tail_length(post_sample_size, r_eff[i]) @@ -139,11 +139,11 @@ function psis( if calc_ess - ess = psis_ess(weights, r_eff) - inf_ess = sup_ess(weights, r_eff) + ess = psis_ess(weights_mat, r_eff) + inf_ess = sup_ess(weights_mat, r_eff) else - ess = similar(weights, 1) - inf_ess = similar(weights, 1) + ess = similar(weights_mat, 1) + inf_ess = similar(weights_mat, 1) ess .= NaN inf_ess .= NaN end From 9a71830265ef97d817f6eac8b5ec77cf3522e344 Mon Sep 17 00:00:00 2001 From: Closed-Limelike-Curves Date: Sun, 26 Sep 2021 21:24:06 -0500 Subject: [PATCH 12/14] Dimension fix --- src/ESS.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ESS.jl b/src/ESS.jl index 6781836..728d8fe 100644 --- a/src/ESS.jl +++ b/src/ESS.jl @@ -46,7 +46,7 @@ See `?relative_eff` to calculate `r_eff`. function psis_ess( weights::AbstractMatrix{T}, r_eff::AbstractVector{T} ) where {T <: Real} - @tullio sum_of_squares[x] := xlogx(weights[x, y, z]) |> exp + @tullio sum_of_squares[x] := xlogx(weights[x, y]) |> exp return r_eff ./ sum_of_squares end From 38f57761a00e4de0c558cd7be1d7a9f742eda65f Mon Sep 17 00:00:00 2001 From: Closed-Limelike-Curves Date: Mon, 27 Sep 2021 17:42:53 -0500 Subject: [PATCH 13/14] add log_weights --- Project.toml | 1 + src/GPD.jl | 14 +++++---- src/ImportanceSampling.jl | 63 ++++++++++++++++++++++----------------- src/TuringHelpers.jl | 12 -------- test/tests/BasicTests.jl | 2 +- 5 files changed, 46 insertions(+), 46 deletions(-) diff --git a/Project.toml b/Project.toml index 12320cd..0f10fde 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ Lazy = "50d2b5c4-7a5e-59d5-8109-a42b560f39c0" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0" +Memoize = "c03570c3-d221-55d1-a50c-7939bbd78826" NamedDims = "356022a1-0364-5f58-8944-0da4b18d706f" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" diff --git a/src/GPD.jl b/src/GPD.jl index c18907d..2c574b8 100644 --- a/src/GPD.jl +++ b/src/GPD.jl @@ -5,8 +5,9 @@ using Tullio """ - gpdfit( - sample::AbstractVector{T<:Real}; + gpd_fit( + sample::AbstractVector{T<:Real}, + r_eff::T = 1; wip::Bool=true, min_grid_pts::Integer=30, sort_sample::Bool=false @@ -29,12 +30,13 @@ generalized Pareto distribution (GPD), assuming the location parameter is 0. Estimation method taken from Zhang, J. and Stephens, M.A. (2009). The parameter ξ is the negative of k. """ -function gpdfit( - sample::AbstractVector{T}; +function gpd_fit( + sample::AbstractVector{T}, + r_eff::T=1; wip::Bool=true, min_grid_pts::Integer=30, sort_sample::Bool=false, -) where {T <: Real} +) where T<:Real len = length(sample) # sample must be sorted, but we can skip if sample is already sorted @@ -70,7 +72,7 @@ function gpdfit( # Drag towards .5 to reduce variance for small len if wip - @fastmath ξ = (ξ * len + 0.5 * n_0) / (len + n_0) + @fastmath ξ = (r_eff * ξ * len + 0.5 * n_0) / (r_eff * len + n_0) end return ξ, σ diff --git a/src/ImportanceSampling.jl b/src/ImportanceSampling.jl index 6e7e547..75ca757 100644 --- a/src/ImportanceSampling.jl +++ b/src/ImportanceSampling.jl @@ -24,9 +24,7 @@ A struct containing the results of Pareto-smoothed importance sampling. # Fields - - `log_weights`: A vector of smoothed and truncated but *unnormalized* importance sampling - weights. - - `weights`: A lazy + - `weights`: A vector of smoothed, truncated, and normalized importance sampling weights. - `pareto_k`: Estimates of the shape parameter `k` of the generalized Pareto distribution. - `ess`: Estimated effective sample size for each LOO evaluation, based on the variance of the weights. @@ -54,6 +52,23 @@ struct Psis{ end +function Base.getproperty(psis_obj::Psis, k::Symbol) + if k === :log_weights + return log.(getfield(psis_obj, :weights)) + else + return getfield(psis_obj, k) + end +end + + +function Base.propertynames(psis_object::Psis) + return ( + fieldnames(typeof(psis_object))..., + :log_weights, + ) +end + + function Base.show(io::IO, ::MIME"text/plain", psis_object::Psis) table = hcat(psis_object.pareto_k, psis_object.ess, psis_object.sup_ess) post_samples = psis_object.posterior_sample_size @@ -79,7 +94,7 @@ end """ psis( log_ratios::AbstractArray{T<:Real}, - r_eff::AbstractVector; + r_eff::AbstractVector{T}; source::String="mcmc" ) -> Psis @@ -100,17 +115,17 @@ Implements Pareto-smoothed importance sampling (PSIS). - `source::String="mcmc"`: A string or symbol describing the source of the sample being used. If `"mcmc"`, adjusts ESS for autocorrelation. Otherwise, samples are assumed to be independent. Currently permitted values are $SAMPLE_SOURCES. - - `log_weights::Bool`: If `true` - - `calc_ess::Bool = true` + - `calc_ess::Bool=true`: If `false`, do not calculate ESS diagnostics. Attempting to + access ESS diagnostics will return an empty list. See also: [`relative_eff`]@ref, [`psis_loo`]@ref, [`psis_ess`]@ref. """ function psis( - log_ratios::AbstractArray{<:Real, 3}; - r_eff::AbstractVector{<:Real}=similar(log_ratios, 0), + log_ratios::AbstractArray{T, 3}; + r_eff::AbstractVector{T}=similar(log_ratios, 0), source::Union{AbstractString, Symbol}="mcmc", calc_ess::Bool = true -) +) where T <: Real source = lowercase(String(source)) dims = size(log_ratios) @@ -131,7 +146,7 @@ function psis( ξ = similar(r_eff) @inbounds Threads.@threads for i in eachindex(tail_length) tail_length[i] = _def_tail_length(post_sample_size, r_eff[i]) - ξ[i] = @views psis!(weights_mat[i, :], tail_length[i]) + ξ[i] = @views psis!(weights_mat[i, :], r_eff[i]; tail_length=tail_length[i]) end @tullio norm_const[i] := weights[i, j, k] @@ -142,10 +157,8 @@ function psis( ess = psis_ess(weights_mat, r_eff) inf_ess = sup_ess(weights_mat, r_eff) else - ess = similar(weights_mat, 1) - inf_ess = similar(weights_mat, 1) - ess .= NaN - inf_ess .= NaN + ess = similar(weights_mat, 0) + inf_ess = similar(weights_mat, 0) end return Psis( @@ -207,9 +220,10 @@ log-weights. Unlike the methods for arrays, `psis!` performs no checks to make sure the input values are valid. """ -function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer; +function psis!(is_ratios::AbstractVector{T}, r_eff::T=one(T); + tail_length::Integer = _def_tail_length(length(is_ratios), r_eff), log_weights::Bool=false -) +) where T<:Real len = length(is_ratios) tail_start = len - tail_length + 1 # index of smallest tail value @@ -227,7 +241,7 @@ function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer; # Get value just before the tail starts: cutoff = is_ratios[tail_start - 1] - ξ = _psis_smooth_tail!(tail, cutoff) + ξ = _psis_smooth_tail!(tail, cutoff, r_eff) # truncate at max of raw weights (1 after scaling) clamp!(is_ratios, 0, 1) @@ -242,30 +256,25 @@ function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer; end -function psis!(is_ratios::AbstractVector{<:Real}, r_eff::Real=1) - tail_length = _def_tail_length(length(is_ratios), r_eff) - return psis!(is_ratios, tail_length) -end - - """ _def_tail_length(log_ratios::AbstractVector, r_eff::Real) -> Integer Define the tail length as in Vehtari et al. (2019), with the small addition that the tail must a multiple of `32*bit_length` (which improves performance). """ -function _def_tail_length(length::Integer, r_eff::Real=1) +function _def_tail_length(length::Integer, r_eff::Real=one(T)) return min(cld(length, 5), ceil(3 * sqrt(length / r_eff))) |> Int end """ - _psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T) where {T<:Real} -> ξ::T + _psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T, r_eff::T=1) where {T<:Real} + -> ξ::T Takes an *already sorted* vector of observations from the tail and smooths it *in place* with PSIS before returning shape parameter `ξ`. """ -function _psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T) where {T <: Real} +function _psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T, r_eff::T=one(T)) where {T <: Real} len = length(tail) if any(isinf.(tail)) return ξ = Inf @@ -273,7 +282,7 @@ function _psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T) where {T <: Real @. tail = tail - cutoff # save time not sorting since tail is already sorted - ξ, σ = gpdfit(tail) + ξ, σ = gpd_fit(tail, r_eff) @. tail = gpd_quantile(($(1:len) - 0.5) / len, ξ, σ) + cutoff end return ξ diff --git a/src/TuringHelpers.jl b/src/TuringHelpers.jl index 15274ce..0a15b00 100644 --- a/src/TuringHelpers.jl +++ b/src/TuringHelpers.jl @@ -8,11 +8,7 @@ const TURING_MODEL_ARG = """ """ -<<<<<<< HEAD pointwise_log_likelihoods(model::DynamicPPL.Model, chains::Chains) -> Array -======= - -> Array ->>>>>>> main Compute pointwise log-likelihoods from a Turing model. @@ -63,11 +59,7 @@ end """ -<<<<<<< HEAD loo_from_psis(model::DynamicPPL.Model, chains::Chains, args...; kwargs...) -> PsisLoo -======= - psis_loo(model::DynamicPPL.Model, chains::Chains, psis::Psis) -> PsisLoo ->>>>>>> main Use Pareto-Smoothed Importance Sampling to calculate the leave-one-out cross validation score from a `Chains` object, a Turing model, and a precalculated `Psis` object. @@ -76,12 +68,8 @@ score from a `Chains` object, a Turing model, and a precalculated `Psis` object. - $CHAINS_ARG - $TURING_MODEL_ARG -<<<<<<< HEAD - -======= - `psis`: A `Psis` object containing the results of Pareto smoothed importance sampling. ->>>>>>> main See also: [`psis`](@ref), [`psis_loo`](@ref), [`PsisLoo`](@ref). """ function loo_from_psis(model::DynamicPPL.Model, chains::Chains, psis::Psis) diff --git a/test/tests/BasicTests.jl b/test/tests/BasicTests.jl index b140d67..bb0afef 100644 --- a/test/tests/BasicTests.jl +++ b/test/tests/BasicTests.jl @@ -55,7 +55,7 @@ import RData # RMSE less than .2% when using InferenceDiagnostics' ESS @test sqrt(mean((jul_psis.weights ./ r_weights .- 1) .^ 2)) ≤ 0.002 # Max difference is 1% - @test maximum(log.(jul_psis.weights) .- log.(r_weights)) ≤ 0.01 + @test maximum(log.(jul_psis.weights) .- log.(r_weights)) ≤ 0.02 ## Test difference in loo pointwise results From 58ef87f5d77c08d1dcb87e8cd0a5ae16e36b872e Mon Sep 17 00:00:00 2001 From: Closed-Limelike-Curves Date: Mon, 27 Sep 2021 19:56:02 -0500 Subject: [PATCH 14/14] Remove accidental dependency add --- Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Project.toml b/Project.toml index 0f10fde..f4b1747 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,6 @@ version = "0.6.6" [deps] AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" -Lazy = "50d2b5c4-7a5e-59d5-8109-a42b560f39c0" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"