Skip to content

Commit 2da7cfd

Browse files
Add function loo_score() for CRPS and SCRPS using PWM identity (#196)
* feat: Add loo_score function * refactor: add input extraction into main function * docs: fix references and add PWM identity * docs: fix order in docs * fix: update args and fix floats
1 parent 604c79d commit 2da7cfd

File tree

6 files changed

+416
-0
lines changed

6 files changed

+416
-0
lines changed

docs/source/api/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ you should jump to {ref}`array_stats_api` and read forward.
5151
arviz_stats.kl_divergence
5252
arviz_stats.loo_expectations
5353
arviz_stats.loo_metrics
54+
arviz_stats.loo_score
5455
arviz_stats.metrics
5556
arviz_stats.mode
5657
arviz_stats.qds

src/arviz_stats/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
loo_i,
1010
loo_expectations,
1111
loo_metrics,
12+
loo_score,
1213
loo_pit,
1314
loo_approximate_posterior,
1415
loo_subsample,

src/arviz_stats/loo/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from arviz_stats.loo.loo import loo, loo_i
44
from arviz_stats.loo.loo_approximate_posterior import loo_approximate_posterior
55
from arviz_stats.loo.loo_expectations import loo_expectations, loo_metrics
6+
from arviz_stats.loo.loo_score import loo_score
67
from arviz_stats.loo.loo_pit import loo_pit
78
from arviz_stats.loo.loo_subsample import loo_subsample, update_subsample
89
from arviz_stats.loo.loo_moment_match import loo_moment_match
@@ -17,6 +18,7 @@
1718
"loo_approximate_posterior",
1819
"loo_expectations",
1920
"loo_metrics",
21+
"loo_score",
2022
"loo_pit",
2123
"loo_subsample",
2224
"update_subsample",

src/arviz_stats/loo/helper_loo.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
"_select_obs_by_indices",
3636
"_select_obs_by_coords",
3737
"_prepare_full_arrays",
38+
"_validate_crps_input",
3839
]
3940

4041
LooInputs = namedtuple(
@@ -998,3 +999,59 @@ def _reconstruct_upars(upars_new_values, props):
998999
dims=props["dims"],
9991000
coords=props["coords"],
10001001
)
1002+
1003+
1004+
def _has_nan_slice(da, dim):
1005+
"""Check if DataArray has NaN slices along a dimension."""
1006+
other = tuple(dd for dd in da.dims if dd != dim)
1007+
return bool(da.isnull().all(dim=other).any()) if other else bool(da.isnull().any())
1008+
1009+
1010+
def _validate_crps_input(y_pred, y_obs, log_likelihood, *, sample_dims, obs_dims):
1011+
"""Shape and dimension checks."""
1012+
missing_sample = [d for d in sample_dims if d not in y_pred.dims]
1013+
if missing_sample:
1014+
raise ValueError(f"y_pred must include sample dimension '{missing_sample[0]}'")
1015+
1016+
missing_obs = [
1017+
d
1018+
for d in obs_dims
1019+
if (d not in y_pred.dims or d not in y_obs.dims or d not in log_likelihood.dims)
1020+
]
1021+
if missing_obs:
1022+
raise ValueError(f"Missing observation dimension '{missing_obs[0]}' in inputs")
1023+
1024+
ypred_obs_aligned, yobs_aligned = xr.align(y_pred, y_obs, join="inner")
1025+
obs_size_mismatch = [
1026+
d
1027+
for d in obs_dims
1028+
if (
1029+
ypred_obs_aligned.sizes[d] != y_pred.sizes[d] or yobs_aligned.sizes[d] != y_obs.sizes[d]
1030+
)
1031+
]
1032+
if obs_size_mismatch:
1033+
raise ValueError(f"Size mismatch in observation dim '{obs_size_mismatch[0]}'")
1034+
1035+
nan_padded_obs = [d for d in obs_dims if _has_nan_slice(y_obs, d)]
1036+
if nan_padded_obs:
1037+
raise ValueError(f"Size mismatch in observation dim '{nan_padded_obs[0]}'")
1038+
1039+
ypred_ll_aligned, ll_aligned = xr.align(y_pred, log_likelihood, join="inner")
1040+
dims_to_check = (*sample_dims, *obs_dims)
1041+
ll_size_mismatch = [
1042+
d
1043+
for d in dims_to_check
1044+
if (
1045+
ypred_ll_aligned.sizes[d] != y_pred.sizes[d]
1046+
or ll_aligned.sizes[d] != log_likelihood.sizes[d]
1047+
)
1048+
]
1049+
if ll_size_mismatch:
1050+
d0 = ll_size_mismatch[0]
1051+
if d0 in sample_dims:
1052+
raise ValueError(f"Size mismatch in sample dimension '{d0}'")
1053+
raise ValueError(f"Size mismatch in observation dim '{d0}'")
1054+
1055+
nan_padded_sample = [d for d in sample_dims if _has_nan_slice(log_likelihood, d)]
1056+
if nan_padded_sample:
1057+
raise ValueError(f"Size mismatch in sample dimension '{nan_padded_sample[0]}'")

0 commit comments

Comments
 (0)