Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "InferenceObjects"
uuid = "b5cf5a8d-e756-4ee3-b014-01d49d192c00"
authors = ["Seth Axen <[email protected]> and contributors"]
version = "0.4.12"
version = "0.4.13"

[deps]
ANSIColoredPrinters = "a4c015fc-c6ff-483c-b24f-f7ea428134e9"
Expand Down Expand Up @@ -32,7 +32,7 @@ MLJBase = "1"
NCDatasets = "0.12.6, 0.13, 0.14"
OffsetArrays = "1"
OrderedCollections = "1.6"
PosteriorStats = "0.1.1, 0.2"
PosteriorStats = "0.3"
Random = "1"
StatsBase = "0.33.7, 0.34"
Tables = "1.11.0"
Expand Down
3 changes: 3 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
PosteriorStats = "7f36be82-ad55-44ba-a5c0-b8b5480d7aa5"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[sources]
InferenceObjects = {path = ".."}

[compat]
Documenter = "1"
DocumenterInterLinks = "1"
67 changes: 30 additions & 37 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,34 +29,37 @@ doctestfilters = [
r"\s+\"created_at\" => .*", # ignore timestamps in doctests
]

makedocs(;
modules=[
InferenceObjects,
Base.get_extension(InferenceObjects, :InferenceObjectsMCMCDiagnosticToolsExt),
Base.get_extension(InferenceObjects, :InferenceObjectsPosteriorStatsExt),
],
authors="Seth Axen <[email protected]> and contributors",
repo=Remotes.GitHub("arviz-devs", "InferenceObjects.jl"),
sitename="InferenceObjects.jl",
format=Documenter.HTML(;
prettyurls=get(ENV, "CI", "false") == "true",
canonical="https://arviz-devs.github.io/InferenceObjects.jl",
edit_link="main",
assets=String[],
),
pages=[
"Home" => "index.md",
"Dataset" => "dataset.md",
"InferenceData" => "inference_data.md",
"Extensions" => [
"MCMCDiagnosticTools" => "extensions/mcmcdiagnostictools.md",
"PosteriorStats" => "extensions/posteriorstats.md",
# Increase the terminal width from 80 to 100 chars to avoid column truncation
withenv("COLUMNS" => 100) do
makedocs(;
modules=[
InferenceObjects,
Base.get_extension(InferenceObjects, :InferenceObjectsMCMCDiagnosticToolsExt),
Base.get_extension(InferenceObjects, :InferenceObjectsPosteriorStatsExt),
],
],
doctestfilters=doctestfilters,
warnonly=:missing_docs,
plugins=[links],
)
authors="Seth Axen <[email protected]> and contributors",
repo=Remotes.GitHub("arviz-devs", "InferenceObjects.jl"),
sitename="InferenceObjects.jl",
format=Documenter.HTML(;
prettyurls=get(ENV, "CI", "false") == "true",
canonical="https://arviz-devs.github.io/InferenceObjects.jl",
edit_link="main",
assets=String[],
),
pages=[
"Home" => "index.md",
"Dataset" => "dataset.md",
"InferenceData" => "inference_data.md",
"Extensions" => [
"MCMCDiagnosticTools" => "extensions/mcmcdiagnostictools.md",
"PosteriorStats" => "extensions/posteriorstats.md",
],
],
doctestfilters=doctestfilters,
warnonly=:missing_docs,
plugins=[links],
)
end

# run doctests on extensions
function get_extension(mod::Module, name::Symbol)
Expand All @@ -67,16 +70,6 @@ function get_extension(mod::Module, name::Symbol)
end
end

using MCMCDiagnosticTools: MCMCDiagnosticTools
using PosteriorStats: PosteriorStats
for extended_pkg in (MCMCDiagnosticTools, PosteriorStats)
extension_name = Symbol("InferenceObjects", extended_pkg, "Ext")
@info "Running doctests for extension $(extension_name)"
mod = get_extension(InferenceObjects, extension_name)
DocMeta.setdocmeta!(mod, :DocTestSetup, :(using $(Symbol(extended_pkg))))
doctest(mod; manual=false)
end

deploydocs(;
repo="github.com/arviz-devs/InferenceObjects.jl", devbranch="main", push_preview=true
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ using InferenceObjects: InferenceObjects
using PosteriorStats: PosteriorStats
using StatsBase: StatsBase

import PosteriorStats: hdi, loo, loo_pit, r2_score, summarize, waic
import PosteriorStats: eti, hdi, loo, loo_pit, r2_score, summarize, waic
import StatsBase: summarystats

export hdi, loo, loo_pit, r2_score, summarize, waic, summarystats
export eti, hdi, loo, loo_pit, r2_score, summarize, waic, summarystats

maplayers = isdefined(DimensionalData, :maplayers) ? DimensionalData.maplayers : map

include("utils.jl")
include("hdi.jl")
include("ci.jl")
include("loo.jl")
include("waic.jl")
include("loo_pit.jl")
Expand Down
27 changes: 27 additions & 0 deletions ext/InferenceObjectsPosteriorStatsExt/ci.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@

for (ci_fun, ci_desc) in
(:eti => "equal-tailed interval (ETI)", :hdi => "highest density interval (HDI)")
@eval begin
# this pattern ensures that the type is completely specified at compile time
@doc """
$($ci_fun)(data::InferenceData; kwargs...) -> Dataset
$($ci_fun)(data::Dataset; kwargs...) -> Dataset

Calculate the $($ci_desc) for each parameter in the data.

For more details and a description of the `kwargs`, see
[`PosteriorStats.$($ci_fun)`](@extref).
"""
function PosteriorStats.$(ci_fun)(data::InferenceObjects.InferenceData; kwargs...)
return PosteriorStats.$(ci_fun)(data.posterior; kwargs...)
end
function PosteriorStats.$(ci_fun)(data::InferenceObjects.Dataset; kwargs...)
ds = maplayers(data) do var
return _as_dimarray(
PosteriorStats.$(ci_fun)(_params_array(var); kwargs...), var
)
end
return DimensionalData.rebuild(ds; metadata=DimensionalData.NoMetadata())
end
end
end
37 changes: 0 additions & 37 deletions ext/InferenceObjectsPosteriorStatsExt/hdi.jl

This file was deleted.

4 changes: 2 additions & 2 deletions ext/InferenceObjectsPosteriorStatsExt/loo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ julia> idata = load_example_data("centered_eight");
julia> loo(idata)
PSISLOOResult with estimates
elpd elpd_mcse p p_mcse
-31 1.4 0.9 0.33
elpd se_elpd p se_p
-31 1.4 0.9 0.33
and PSISResult with 500 draws, 4 chains, and 8 parameters
Pareto shape (k) diagnostic values:
Expand Down
27 changes: 13 additions & 14 deletions ext/InferenceObjectsPosteriorStatsExt/summarize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ function StatsBase.summarystats(data::InferenceObjects.Dataset; kwargs...)
end

@doc """
summarize(data::InferenceData, group=:posterior, stats_funs...; kwargs...)
summarize(data::InferenceData, stats_funs...; group=:posterior, kwargs...)
summarize(data::Dataset, stats_funs...; kwargs...)

Compute summary statistics for the data using the provided functions.

For verbose variable labels, provide `compat_labels=false`. For details on `stats_funs` and
For verbose variable labels, provide `compact_labels=false`. For details on `stats_funs` and
`kwargs`, see [`PosteriorStats.summarize`](@extref).

# Examples
Expand All @@ -33,18 +33,17 @@ julia> data = load_example_data("centered_eight");

julia> summarize(data)
SummaryStats
mean std hdi_3% hdi_97% mcse_mean mcse_std ess ⋯
mu 4.2 3.3 -1.61 10.3 0.21 0.088 ⋯
theta[Choate] 6.4 5.9 -3.68 17.9 0.25 0.20 ⋯
theta[Deerfield] 5.0 4.9 -4.98 13.4 0.21 0.15 ⋯
theta[Phillips Andover] 3.4 5.4 -7.54 12.9 0.23 0.17 ⋯
theta[Phillips Exeter] 4.8 5.2 -5.11 14.1 0.21 0.21 ⋯
theta[Hotchkiss] 3.5 4.8 -6.12 12.0 0.25 0.15 ⋯
theta[Lawrenceville] 3.7 5.2 -6.50 12.7 0.22 0.21 ⋯
theta[St. Paul's] 6.5 5.2 -2.67 16.9 0.22 0.15 ⋯
theta[Mt. Hermon] 4.8 5.7 -5.97 15.4 0.24 0.23 ⋯
tau 4.3 3.0 0.715 9.41 0.22 0.14 ⋯
3 columns omitted
mean std eti94 ess_tail ess_bulk rhat mcse_mean mcse_std
mu 4.2 3.3 -2.11 .. 9.90 622 241 1.03 0.21 0.088
theta[Choate] 6.4 5.9 -3.05 .. 19.1 937 572 1.01 0.25 0.20
theta[Deerfield] 5.0 4.9 -4.49 .. 14.2 1214 532 1.01 0.21 0.15
theta[Phillips Andover] 3.4 5.4 -8.17 .. 12.7 1017 511 1.01 0.23 0.17
theta[Phillips Exeter] 4.8 5.2 -4.84 .. 14.5 911 572 1.01 0.21 0.21
theta[Hotchkiss] 3.5 4.8 -6.11 .. 12.0 789 347 1.02 0.25 0.15
theta[Lawrenceville] 3.7 5.2 -6.62 .. 12.6 957 506 1.01 0.22 0.21
theta[St. Paul's] 6.5 5.2 -2.38 .. 18.3 1031 528 1.01 0.22 0.15
theta[Mt. Hermon] 4.8 5.7 -5.52 .. 16.0 1045 538 1.01 0.24 0.23
tau 4.3 3.0 1.06 .. 11.5 214 128 1.03 0.22 0.14
```

Compute the mean, standard deviation, median, and median absolute deviation of the `theta`
Expand Down
11 changes: 10 additions & 1 deletion ext/InferenceObjectsPosteriorStatsExt/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,17 @@ function observations_and_predictions(
end
end

# reshape to (ndraws, nchains, nparams...) and drop the dimensions
function _params_array(data)
sample_dims = Dimensions.dims(data, InferenceObjects.DEFAULT_SAMPLE_DIMS)
param_dims = Dimensions.otherdims(data, sample_dims)
dims_combined = Dimensions.combinedims((sample_dims..., param_dims...))
Dimensions.dimsmatch(Dimensions.dims(data), dims_combined) && return data
return PermutedDimsArray(data, dims_combined)
end

_as_dimarray(x::DimensionalData.AbstractDimArray, ::DimensionalData.AbstractDimArray) = x
function _as_dimarray(x::Union{Real,Missing}, arr::DimensionalData.AbstractDimArray)
function _as_dimarray(x, arr::DimensionalData.AbstractDimArray)
return Dimensions.rebuild(arr, fill(x), ())
end

Expand Down
4 changes: 2 additions & 2 deletions ext/InferenceObjectsPosteriorStatsExt/waic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ julia> idata = load_example_data("centered_eight");
julia> waic(idata)
WAICResult with estimates
elpd elpd_mcse p p_mcse
-31 1.4 0.9 0.32
elpd se_elpd p se_p
-31 1.4 0.9 0.32
```
"""
function PosteriorStats.waic(
Expand Down
Loading
Loading