Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
DimensionalData = "0.24"
Expand All @@ -21,6 +22,7 @@ Printf = "1.6"
RecipesBase = "1"
ReferenceTests = "0.9, 0.10"
Statistics = "1.6"
StatsBase = "0.32, 0.33, 0.34"
julia = "1.6"

[extras]
Expand Down
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
PSISResult
psis
ess_is
PSIS.expectation
```

## Plotting
Expand Down
2 changes: 2 additions & 0 deletions src/PSIS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module PSIS
using LogExpFunctions: LogExpFunctions
using Printf: @sprintf
using Statistics: Statistics
using StatsBase: StatsBase

export PSISPlots
export PSISResult
Expand All @@ -12,6 +13,7 @@ include("utils.jl")
include("generalized_pareto.jl")
include("core.jl")
include("ess.jl")
include("expectation.jl")
include("recipes/plots.jl")

end
73 changes: 73 additions & 0 deletions src/expectation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""
expectation(x, psis_result::PSISResult; kind=Statistics.mean)

Compute the expectation of `x` with respect to the weights in `psis_result`.

# Arguments

- `x`: An array of values of shape `(draws[, chains[, params...]])`, to compute the
expectation of with respect to smoothed importance weights.
- `psis_result`: A `PSISResult` object containing the smoothed importance weights with shape
`(draws[, chains, params...])`.

# Keywords

- `kind=Statistics.mean`: The type of expectation to be computed. It can be any function
that has a method for computing the weighted expectation
`f(x::AbstractVector, weights::AbstractVector) -> Real`. In particular, the following
are supported:

+ `Statistics.mean`
+ `Statistics.median`
+ `Statistics.std`
+ `Statistics.var`
+ `Base.Fix2(Statistics.quantile, p::Real)` for `quantile(x, weights, p)`

# Returns

- `values`: An array of shape `(other..., params...)` or real number of `other` and `params`
are empty containing the expectation of `x` with respect to the smoothed importance
weights.
"""
function expectation(x::AbstractArray, psis_result::PSISResult; kind=Statistics.mean)
log_weights = psis_result.log_weights
weights = psis_result.weights

param_dims = _param_dims(log_weights)
exp_dims = _param_dims(x)
if !isempty(exp_dims) && length(exp_dims) != length(param_dims)
throw(
ArgumentError(
"The trailing dimensions of `x` must match the parameter dimensions of `psis_result.weights`",
),
)
end
param_axes = map(Base.Fix1(axes, log_weights), param_dims)
exp_axes = map(Base.Fix1(axes, x), exp_dims)
if !isempty(exp_axes) && exp_axes != param_axes
throw(
ArgumentError(
"The trailing axes of `x` must match the parameter axes of `psis_result.weights`",
),
)
end

T = Base.promote_eltype(x, log_weights)
values = similar(x, T, param_axes)

for i in _eachparamindex(weights)
w_i = StatsBase.AnalyticWeights(vec(_selectparam(weights, i)), 1)
x_i = vec(ndims(x) < 3 ? x : _selectparam(x, i))
values[i] = _expectation(kind, x_i, w_i)
end

iszero(ndims(values)) && return values[]

return values
end

_expectation(f, x, weights) = f(x, weights)
function _expectation(f::Base.Fix2{typeof(Statistics.quantile),<:Real}, x, weights)
prob = f.x
return Statistics.quantile(x, weights, prob)
end