diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index f79cec3..5f9114c 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -17,7 +17,6 @@ jobs: version: - '1' # latest stable 1.x release of Julia - '1.6' # oldest supported version - - 'nightly' os: - ubuntu-latest arch: diff --git a/Project.toml b/Project.toml index 12320cd..f4b1747 100644 --- a/Project.toml +++ b/Project.toml @@ -6,10 +6,10 @@ version = "0.6.6" [deps] AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" -Lazy = "50d2b5c4-7a5e-59d5-8109-a42b560f39c0" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0" +Memoize = "c03570c3-d221-55d1-a50c-7939bbd78826" NamedDims = "356022a1-0364-5f58-8944-0da4b18d706f" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" diff --git a/src/AbstractCV.jl b/src/AbstractCV.jl index 6dcb4de..3fc2bcc 100644 --- a/src/AbstractCV.jl +++ b/src/AbstractCV.jl @@ -2,9 +2,8 @@ using AxisKeys using PrettyTables -export AbstractCVMethod, AbstractCV +# export AbstractCVMethod, AbstractCV -const POINTWISE_LABELS = (:cv_elpd, :naive_lpd, :p_eff, :ess, :pareto_k) const CV_DESC = """ # Fields @@ -73,12 +72,12 @@ An abstract type used in cross-validation. abstract type AbstractCV end -""" - AbstractCVMethod +# """ +# AbstractCVMethod -An abstract type used to dispatch the correct method for cross validation. -""" -abstract type AbstractCVMethod end +# An abstract type used to dispatch the correct method for cross validation. +# """ +# abstract type AbstractCVMethod end ########################## diff --git a/src/ESS.jl b/src/ESS.jl index 9d674fe..728d8fe 100644 --- a/src/ESS.jl +++ b/src/ESS.jl @@ -5,7 +5,7 @@ export relative_eff, psis_ess, sup_ess """ relative_eff( - sample::AbstractArray{Real, 3}; + sample::AbstractArray{<:Real, 3}; method=MCMCDiagnosticTools.FFTESSMethod() ) @@ -16,7 +16,7 @@ by the nominal sample size. - `sample::AbstractArray{<:Real, 3}`: An array of log-likelihood values. """ -function relative_eff(sample::AbstractArray{<:Real, 3}; maxlag=size(sample, 2), kwargs...) +function relative_eff(sample::AbstractArray{<:Real,3}; maxlag=size(sample, 2), kwargs...) dims = size(sample) post_sample_size = dims[2] * dims[3] ess_sample = permutedims(sample, [2, 1, 3]) @@ -60,13 +60,13 @@ end """ function sup_ess( - weights::AbstractVector{T}, + weights::AbstractMatrix{T}, r_eff::AbstractVector{T} ) -> AbstractVector Calculate the supremum-based effective sample size of a PSIS sample, i.e. the inverse of the -maximum weight. This measure is more trustworthy than the `ess` from `psis_ess`. It uses the -L-∞ norm. +maximum weight. This measure is more sensitive than the `ess` from `psis_ess`, but also +much more variable. It uses the L-∞ norm. # Arguments - `weights`: A set of importance sampling weights derived from PSIS. diff --git a/src/GPD.jl b/src/GPD.jl index c18907d..2c574b8 100644 --- a/src/GPD.jl +++ b/src/GPD.jl @@ -5,8 +5,9 @@ using Tullio """ - gpdfit( - sample::AbstractVector{T<:Real}; + gpd_fit( + sample::AbstractVector{T<:Real}, + r_eff::T = 1; wip::Bool=true, min_grid_pts::Integer=30, sort_sample::Bool=false @@ -29,12 +30,13 @@ generalized Pareto distribution (GPD), assuming the location parameter is 0. Estimation method taken from Zhang, J. and Stephens, M.A. (2009). The parameter ξ is the negative of k. """ -function gpdfit( - sample::AbstractVector{T}; +function gpd_fit( + sample::AbstractVector{T}, + r_eff::T=1; wip::Bool=true, min_grid_pts::Integer=30, sort_sample::Bool=false, -) where {T <: Real} +) where T<:Real len = length(sample) # sample must be sorted, but we can skip if sample is already sorted @@ -70,7 +72,7 @@ function gpdfit( # Drag towards .5 to reduce variance for small len if wip - @fastmath ξ = (ξ * len + 0.5 * n_0) / (len + n_0) + @fastmath ξ = (r_eff * ξ * len + 0.5 * n_0) / (r_eff * len + n_0) end return ξ, σ diff --git a/src/ImportanceSampling.jl b/src/ImportanceSampling.jl index 9c8de24..75ca757 100644 --- a/src/ImportanceSampling.jl +++ b/src/ImportanceSampling.jl @@ -9,7 +9,7 @@ double check it is correct. const MIN_TAIL_LEN = 5 # Minimum size of a tail for PSIS to give sensible answers const SAMPLE_SOURCES = ["mcmc", "vi", "other"] -export psis, psis!, PsisLoo, PsisLooMethod, Psis +export psis, psis!, Psis ########################### @@ -24,9 +24,7 @@ A struct containing the results of Pareto-smoothed importance sampling. # Fields - - `log_weights`: A vector of smoothed and truncated but *unnormalized* importance sampling - weights. - - `weights`: A lazy + - `weights`: A vector of smoothed, truncated, and normalized importance sampling weights. - `pareto_k`: Estimates of the shape parameter `k` of the generalized Pareto distribution. - `ess`: Estimated effective sample size for each LOO evaluation, based on the variance of the weights. @@ -39,21 +37,38 @@ A struct containing the results of Pareto-smoothed importance sampling. - `data_size`: How many data points were used for PSIS. """ struct Psis{ - RealType <: Real, - AT <: AbstractArray{RealType, 3}, - VT <: AbstractVector{RealType}, + R <: Real, + AT <: AbstractArray{R, 3}, + VT <: AbstractVector{R} } weights::AT pareto_k::VT ess::VT sup_ess::VT r_eff::VT - tail_len::Vector{Int} + tail_len::AbstractVector{Int} posterior_sample_size::Int data_size::Int end +function Base.getproperty(psis_obj::Psis, k::Symbol) + if k === :log_weights + return log.(getfield(psis_obj, :weights)) + else + return getfield(psis_obj, k) + end +end + + +function Base.propertynames(psis_object::Psis) + return ( + fieldnames(typeof(psis_object))..., + :log_weights, + ) +end + + function Base.show(io::IO, ::MIME"text/plain", psis_object::Psis) table = hcat(psis_object.pareto_k, psis_object.ess, psis_object.sup_ess) post_samples = psis_object.posterior_sample_size @@ -79,14 +94,16 @@ end """ psis( log_ratios::AbstractArray{T<:Real}, - r_eff::AbstractVector; + r_eff::AbstractVector{T}; source::String="mcmc" ) -> Psis Implements Pareto-smoothed importance sampling (PSIS). # Arguments + ## Positional Arguments + - `log_ratios::AbstractArray`: A 2d or 3d array of (unnormalized) importance ratios on the log scale. Indices must be ordered as `[data, step, chain]`. The chain index can be left off if there is only one chain, or if keyword argument `chain_index` is provided. @@ -98,15 +115,17 @@ Implements Pareto-smoothed importance sampling (PSIS). - `source::String="mcmc"`: A string or symbol describing the source of the sample being used. If `"mcmc"`, adjusts ESS for autocorrelation. Otherwise, samples are assumed to be independent. Currently permitted values are $SAMPLE_SOURCES. + - `calc_ess::Bool=true`: If `false`, do not calculate ESS diagnostics. Attempting to + access ESS diagnostics will return an empty list. See also: [`relative_eff`]@ref, [`psis_loo`]@ref, [`psis_ess`]@ref. """ function psis( - log_ratios::AbstractArray{<:Real, 3}; - r_eff::AbstractVector{<:Real}=similar(log_ratios, 0), + log_ratios::AbstractArray{T, 3}; + r_eff::AbstractVector{T}=similar(log_ratios, 0), source::Union{AbstractString, Symbol}="mcmc", - log_weights::Bool=true -) + calc_ess::Bool = true +) where T <: Real source = lowercase(String(source)) dims = size(log_ratios) @@ -115,27 +134,35 @@ function psis( post_sample_size = dims[2] * dims[3] # Reshape to matrix (easier to deal with) - log_ratios = reshape(log_ratios, data_size, post_sample_size) - r_eff = _generate_r_eff(log_ratios, dims, r_eff, source) - _check_input_validity_psis(reshape(log_ratios, dims), r_eff) - weights = @. exp(log_ratios - $maximum(log_ratios; dims=2)) + log_ratios_mat = reshape(log_ratios, data_size, post_sample_size) + r_eff = _generate_r_eff(log_ratios_mat, dims, r_eff, source) + _check_input_validity_psis(log_ratios, r_eff) + weights = similar(log_ratios) + weights_mat = reshape(weights, data_size, post_sample_size) + @. weights = exp(log_ratios - $maximum(log_ratios; dims=(2,3))) + - tail_length = Vector{Int}(undef, data_size) + tail_length = similar(r_eff, Int) ξ = similar(r_eff) @inbounds Threads.@threads for i in eachindex(tail_length) tail_length[i] = _def_tail_length(post_sample_size, r_eff[i]) - ξ[i] = @views psis!(weights[i, :], tail_length[i]) + ξ[i] = @views psis!(weights_mat[i, :], r_eff[i]; tail_length=tail_length[i]) end - @tullio norm_const[i] := weights[i, j] + @tullio norm_const[i] := weights[i, j, k] @. weights = weights / norm_const - ess = psis_ess(weights, r_eff) - inf_ess = sup_ess(weights, r_eff) - weights = reshape(weights, dims) + + if calc_ess + ess = psis_ess(weights_mat, r_eff) + inf_ess = sup_ess(weights_mat, r_eff) + else + ess = similar(weights_mat, 0) + inf_ess = similar(weights_mat, 0) + end return Psis( - weights, + weights, ξ, ess, inf_ess, @@ -193,10 +220,11 @@ log-weights. Unlike the methods for arrays, `psis!` performs no checks to make sure the input values are valid. """ -function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer; +function psis!(is_ratios::AbstractVector{T}, r_eff::T=one(T); + tail_length::Integer = _def_tail_length(length(is_ratios), r_eff), log_weights::Bool=false -) - +) where T<:Real + len = length(is_ratios) tail_start = len - tail_length + 1 # index of smallest tail value @@ -213,7 +241,7 @@ function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer; # Get value just before the tail starts: cutoff = is_ratios[tail_start - 1] - ξ = _psis_smooth_tail!(tail, cutoff) + ξ = _psis_smooth_tail!(tail, cutoff, r_eff) # truncate at max of raw weights (1 after scaling) clamp!(is_ratios, 0, 1) @@ -228,30 +256,25 @@ function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer; end -function psis!(is_ratios::AbstractVector{<:Real}, r_eff::Real=1) - tail_length = _def_tail_length(length(is_ratios), r_eff) - return psis!(is_ratios, tail_length) -end - - """ _def_tail_length(log_ratios::AbstractVector, r_eff::Real) -> Integer Define the tail length as in Vehtari et al. (2019), with the small addition that the tail must a multiple of `32*bit_length` (which improves performance). """ -function _def_tail_length(length::Integer, r_eff::Real=1) +function _def_tail_length(length::Integer, r_eff::Real=one(T)) return min(cld(length, 5), ceil(3 * sqrt(length / r_eff))) |> Int end """ - _psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T) where {T<:Real} -> ξ::T + _psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T, r_eff::T=1) where {T<:Real} + -> ξ::T Takes an *already sorted* vector of observations from the tail and smooths it *in place* with PSIS before returning shape parameter `ξ`. """ -function _psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T) where {T <: Real} +function _psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T, r_eff::T=one(T)) where {T <: Real} len = length(tail) if any(isinf.(tail)) return ξ = Inf @@ -259,7 +282,7 @@ function _psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T) where {T <: Real @. tail = tail - cutoff # save time not sorting since tail is already sorted - ξ, σ = gpdfit(tail) + ξ, σ = gpd_fit(tail, r_eff) @. tail = gpd_quantile(($(1:len) - 0.5) / len, ξ, σ) + cutoff end return ξ diff --git a/src/LeaveOneOut.jl b/src/LeaveOneOut.jl index 584b806..8fac50b 100644 --- a/src/LeaveOneOut.jl +++ b/src/LeaveOneOut.jl @@ -5,7 +5,7 @@ using Statistics using Printf using Tullio -export loo, psis_loo, loo_from_psis +export loo, psis_loo, loo_from_psis, PsisLoo ##################### @@ -13,19 +13,19 @@ export loo, psis_loo, loo_from_psis ##################### -""" - PsisLooMethod +# """ +# PsisLooMethod -Use Pareto-smoothed importance sampling together with leave-one-out cross validation to -estimate the out-of-sample predictive accuracy. -""" -struct PsisLooMethod <: AbstractCVMethod end +# Use Pareto-smoothed importance sampling together with leave-one-out cross validation to +# estimate the out-of-sample predictive accuracy. +# """ +# struct PsisLooMethod <: AbstractCVMethod end """ PsisLoo <: AbstractCV -A struct containing the results of leave-one-out cross validation using Pareto +A struct containing the results of leave-one-out cross validation computed with Pareto smoothed importance sampling. $CV_DESC @@ -71,17 +71,17 @@ end """ - function loo(args...; method=PsisLooMethod(), kwargs...) -> PsisLoo + function loo(args...; kwargs...) -> PsisLoo -Compute the approximate leave-one-out cross-validation score using the specified method. +Compute an approximate leave-one-out cross-validation score. Currently, this function only serves to call `psis_loo`, but this could change in the -future. The default methods or return type may change without warning; thus, we recommend +future. The default methods or return type may change without warning, so we recommend using `psis_loo` instead if reproducibility is required. See also: [`psis_loo`](@ref), [`PsisLoo`](@ref). """ -function loo(args...; method=PsisLooMethod(), kwargs...) +function loo(args...; kwargs...) return psis_loo(args...; kwargs...) end diff --git a/src/ModelComparison.jl b/src/ModelComparison.jl index 855f3d8..79c2f88 100644 --- a/src/ModelComparison.jl +++ b/src/ModelComparison.jl @@ -10,7 +10,7 @@ A struct containing the results of model comparison. # Fields - - `pointwise::KeyedArray`: An array containing +- `pointwise::KeyedArray`: A `KeyedArray` of pointwise estimates. See [`PsisLoo`]@ref. - `estimates::KeyedArray`: A table containing the results of model comparison, with the following columns -- + `cv_elpd`: The difference in total leave-one-out cross validation scores diff --git a/src/TuringHelpers.jl b/src/TuringHelpers.jl index 15274ce..0a15b00 100644 --- a/src/TuringHelpers.jl +++ b/src/TuringHelpers.jl @@ -8,11 +8,7 @@ const TURING_MODEL_ARG = """ """ -<<<<<<< HEAD pointwise_log_likelihoods(model::DynamicPPL.Model, chains::Chains) -> Array -======= - -> Array ->>>>>>> main Compute pointwise log-likelihoods from a Turing model. @@ -63,11 +59,7 @@ end """ -<<<<<<< HEAD loo_from_psis(model::DynamicPPL.Model, chains::Chains, args...; kwargs...) -> PsisLoo -======= - psis_loo(model::DynamicPPL.Model, chains::Chains, psis::Psis) -> PsisLoo ->>>>>>> main Use Pareto-Smoothed Importance Sampling to calculate the leave-one-out cross validation score from a `Chains` object, a Turing model, and a precalculated `Psis` object. @@ -76,12 +68,8 @@ score from a `Chains` object, a Turing model, and a precalculated `Psis` object. - $CHAINS_ARG - $TURING_MODEL_ARG -<<<<<<< HEAD - -======= - `psis`: A `Psis` object containing the results of Pareto smoothed importance sampling. ->>>>>>> main See also: [`psis`](@ref), [`psis_loo`](@ref), [`PsisLoo`](@ref). """ function loo_from_psis(model::DynamicPPL.Model, chains::Chains, psis::Psis) diff --git a/test/tests/BasicTests.jl b/test/tests/BasicTests.jl index b140d67..bb0afef 100644 --- a/test/tests/BasicTests.jl +++ b/test/tests/BasicTests.jl @@ -55,7 +55,7 @@ import RData # RMSE less than .2% when using InferenceDiagnostics' ESS @test sqrt(mean((jul_psis.weights ./ r_weights .- 1) .^ 2)) ≤ 0.002 # Max difference is 1% - @test maximum(log.(jul_psis.weights) .- log.(r_weights)) ≤ 0.01 + @test maximum(log.(jul_psis.weights) .- log.(r_weights)) ≤ 0.02 ## Test difference in loo pointwise results