Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/api/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,11 @@ refer you to other pages for the full argument or algorithm descriptions.

arviz_stats.base.array_stats.loo
arviz_stats.base.array_stats.loo_approximate_posterior
arviz_stats.base.array_stats.loo_expectation
arviz_stats.base.array_stats.loo_mixture
arviz_stats.base.array_stats.loo_pit
arviz_stats.base.array_stats.loo_quantile
arviz_stats.base.array_stats.loo_r2
arviz_stats.base.array_stats.loo_score
arviz_stats.base.array_stats.loo_summary
```
Expand Down
105 changes: 102 additions & 3 deletions src/arviz_stats/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,12 @@ def pareto_khat(self, sample_dims=None, **kwargs):
"""Compute Pareto k-hat diagnostic."""
return self._apply("pareto_khat", sample_dims=sample_dims, **kwargs)

def loo(self, sample_dims=None, reff=1.0, log_weights=None, pareto_k=None, log_jacobian=None):
def loo(self, sample_dims=None, r_eff=1.0, log_weights=None, pareto_k=None, log_jacobian=None):
"""Compute PSIS-LOO-CV."""
return self._apply(
"loo",
sample_dims=sample_dims,
reff=reff,
r_eff=r_eff,
log_weights=log_weights,
pareto_k=pareto_k,
log_jacobian=log_jacobian,
Expand All @@ -184,13 +184,94 @@ def loo_approximate_posterior(self, log_p, log_q, sample_dims=None, log_jacobian
log_jacobian=log_jacobian,
)

def loo_score(self, y_obs, log_weights, kind="crps", sample_dims=None, **kwargs):
def loo_score(
self,
y_obs,
log_ratios=None,
kind="crps",
r_eff=1.0,
log_weights=None,
pareto_k=None,
sample_dims=None,
**kwargs,
):
"""Compute CRPS or SCRPS with PSIS-LOO-CV weights."""
return self._apply(
"loo_score",
y_obs=y_obs,
log_ratios=log_ratios,
kind=kind,
r_eff=r_eff,
log_weights=log_weights,
pareto_k=pareto_k,
sample_dims=sample_dims,
**kwargs,
)

def loo_pit(
self,
y_obs,
log_ratios=None,
r_eff=1.0,
log_weights=None,
pareto_k=None,
sample_dims=None,
random_state=None,
**kwargs,
):
"""Compute LOO-PIT values with PSIS-LOO-CV weights."""
return self._apply(
"loo_pit",
y_obs=y_obs,
log_ratios=log_ratios,
r_eff=r_eff,
log_weights=log_weights,
pareto_k=pareto_k,
sample_dims=sample_dims,
random_state=random_state,
**kwargs,
)

def loo_expectation(
self,
log_ratios=None,
kind="mean",
r_eff=1.0,
log_weights=None,
pareto_k=None,
sample_dims=None,
**kwargs,
):
"""Compute weighted expectation with PSIS-LOO-CV weights."""
return self._apply(
"loo_expectation",
log_ratios=log_ratios,
kind=kind,
r_eff=r_eff,
log_weights=log_weights,
pareto_k=pareto_k,
sample_dims=sample_dims,
**kwargs,
)

def loo_quantile(
self,
log_ratios=None,
probs=None,
r_eff=1.0,
log_weights=None,
pareto_k=None,
sample_dims=None,
**kwargs,
):
"""Compute weighted quantile with PSIS-LOO-CV weights."""
return self._apply(
"loo_quantile",
log_ratios=log_ratios,
probs=probs,
r_eff=r_eff,
log_weights=log_weights,
pareto_k=pareto_k,
sample_dims=sample_dims,
**kwargs,
)
Expand All @@ -199,6 +280,24 @@ def loo_summary(self, p_loo_i):
"""Aggregate pointwise LOO values."""
return self._apply("loo_summary", p_loo_i=p_loo_i)

def loo_r2(
self,
ypred_loo,
n_simulations=4000,
circular=False,
random_state=42,
**kwargs,
):
"""Compute LOO-adjusted :math:`R^2` using Dirichlet-weighted bootstrap."""
return self._apply(
"loo_r2",
ypred_loo=ypred_loo,
n_simulations=n_simulations,
circular=circular,
random_state=random_state,
**kwargs,
)

def power_scale_lw(self, dim=None, **kwargs):
"""Compute log weights for power-scaling of the DataTree."""
return self._apply("power_scale_lw", dim=dim, **kwargs)
Expand Down
Loading