diff --git a/Project.toml b/Project.toml
index 91a56c24..3adcad0c 100644
--- a/Project.toml
+++ b/Project.toml
@@ -9,6 +9,7 @@ LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 
 [compat]
 DimensionalData = "0.24.2"
@@ -22,6 +23,7 @@ RecipesBase = "1"
 ReferenceTests = "0.9, 0.10"
 StableRNGs = "1"
 Statistics = "1.6"
+StatsBase = "0.33.1, 0.34"
 julia = "1.6"
 
 [extras]
diff --git a/docs/src/api.md b/docs/src/api.md
index 5e7b69c5..bb66f12a 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -6,6 +6,7 @@
 PSISResult
 psis
 ess_is
+PSIS.expectation
 ```
 
 ## Plotting
diff --git a/src/PSIS.jl b/src/PSIS.jl
index f249a936..a0c5c821 100644
--- a/src/PSIS.jl
+++ b/src/PSIS.jl
@@ -3,6 +3,7 @@ module PSIS
 using LogExpFunctions: LogExpFunctions
 using Printf: @sprintf
 using Statistics: Statistics
+using StatsBase: StatsBase
 
 export PSISPlots
 export PSISResult
@@ -12,6 +13,7 @@ include("utils.jl")
 include("generalized_pareto.jl")
 include("core.jl")
 include("ess.jl")
+include("expectation.jl")
 include("recipes/plots.jl")
 
 end
diff --git a/src/expectation.jl b/src/expectation.jl
new file mode 100644
index 00000000..2f1f2600
--- /dev/null
+++ b/src/expectation.jl
@@ -0,0 +1,92 @@
+"""
+    expectation(x, psis_result::PSISResult; kind=Statistics.mean)
+
+Compute a weighted expectation of `x` using the smoothed importance weights stored in
+`psis_result`.
+
+If `x` has fewer than 3 dimensions, the same values of `x` are paired with the weights of
+every parameter; otherwise the parameter slice of `x` matching each set of weights is used.
+
+# Arguments
+
+  - `x`: An array of values of shape `(draws[, chains[, params...]])`, to compute the
+    expectation of with respect to smoothed importance weights.
+  - `psis_result`: A `PSISResult` object containing the smoothed importance weights with
+    shape `(draws[, chains, params...])`.
+
+# Keywords
+
+  - `kind=Statistics.mean`: The type of expectation to be computed. It can be any function
+    that has a method for computing the weighted expectation
+    `f(x::AbstractVector, weights::AbstractVector) -> Real`. In particular, the following
+    are supported:
+
+      + `Statistics.mean`
+      + `Statistics.median`
+      + `Statistics.std`
+      + `Statistics.var`
+      + `Base.Fix2(Statistics.quantile, p::Real)` for `quantile(x, weights, p)`
+
+# Returns
+
+  - `values`: An array whose axes are the parameter axes of the weights, or a single
+    number if the weights have no parameter dimensions, containing the expectation of
+    `x` with respect to the smoothed importance weights.
+
+# Throws
+
+  - `ArgumentError`: if `x` has parameter dimensions whose number or axes differ from
+    those of the weights.
+"""
+function expectation(x::AbstractArray, psis_result::PSISResult; kind=Statistics.mean)
+    log_weights = psis_result.log_weights
+    weights = psis_result.weights
+
+    # `x` may carry no parameter dimensions at all; if it has any, their number must match
+    weight_dims = _param_dims(log_weights)
+    x_dims = _param_dims(x)
+    isempty(x_dims) ||
+        length(x_dims) == length(weight_dims) ||
+        throw(
+            ArgumentError(
+                "The trailing dimensions of `x` must match the parameter dimensions of `psis_result.weights`",
+            ),
+        )
+
+    # the same rule applies to the axes along those dimensions
+    weight_axes = map(d -> axes(log_weights, d), weight_dims)
+    x_axes = map(d -> axes(x, d), x_dims)
+    isempty(x_axes) ||
+        x_axes == weight_axes ||
+        throw(
+            ArgumentError(
+                "The trailing axes of `x` must match the parameter axes of `psis_result.weights`",
+            ),
+        )
+
+    # the output is shaped by the weights' parameter axes and filled index by index
+    T = promote_type(eltype(x), eltype(log_weights))
+    values = similar(x, T, weight_axes)
+
+    # with fewer than 3 dimensions `x` is not sliced; the whole array is reused
+    reuse_x = ndims(x) < 3
+    for i in _eachparamindex(weights)
+        # the weight sum is passed as 1 instead of being recomputed
+        # NOTE(review): presumably valid because the weights are normalized -- confirm
+        w_i = StatsBase.AnalyticWeights(vec(_selectparam(weights, i)), 1)
+        x_i = vec(reuse_x ? x : _selectparam(x, i))
+        values[i] = _expectation(kind, x_i, w_i)
+    end
+
+    # unwrap a zero-dimensional result to its single element
+    return iszero(ndims(values)) ? values[] : values
+end
+
+# Generic case: `kind` is called directly as `f(x, weights)`.
+_expectation(f, x, weights) = f(x, weights)
+
+# `Base.Fix2(Statistics.quantile, p)`: unpack the fixed probability `p` and call the
+# three-argument weighted method `quantile(x, weights, p)`.
+function _expectation(q::Base.Fix2{typeof(Statistics.quantile),<:Real}, x, weights)
+    return Statistics.quantile(x, weights, q.x)
+end