diff --git a/docs/source/api/index.md b/docs/source/api/index.md index e851a8b..07233d5 100644 --- a/docs/source/api/index.md +++ b/docs/source/api/index.md @@ -200,7 +200,11 @@ refer you to other pages for the full argument or algorithm descriptions. arviz_stats.base.array_stats.loo arviz_stats.base.array_stats.loo_approximate_posterior + arviz_stats.base.array_stats.loo_expectation arviz_stats.base.array_stats.loo_mixture + arviz_stats.base.array_stats.loo_pit + arviz_stats.base.array_stats.loo_quantile + arviz_stats.base.array_stats.loo_r2 arviz_stats.base.array_stats.loo_score arviz_stats.base.array_stats.loo_summary ``` diff --git a/src/arviz_stats/accessors.py b/src/arviz_stats/accessors.py index 55f0c48..5a87700 100644 --- a/src/arviz_stats/accessors.py +++ b/src/arviz_stats/accessors.py @@ -155,12 +155,12 @@ def pareto_khat(self, sample_dims=None, **kwargs): """Compute Pareto k-hat diagnostic.""" return self._apply("pareto_khat", sample_dims=sample_dims, **kwargs) - def loo(self, sample_dims=None, reff=1.0, log_weights=None, pareto_k=None, log_jacobian=None): + def loo(self, sample_dims=None, r_eff=1.0, log_weights=None, pareto_k=None, log_jacobian=None): """Compute PSIS-LOO-CV.""" return self._apply( "loo", sample_dims=sample_dims, - reff=reff, + r_eff=r_eff, log_weights=log_weights, pareto_k=pareto_k, log_jacobian=log_jacobian, @@ -184,13 +184,94 @@ def loo_approximate_posterior(self, log_p, log_q, sample_dims=None, log_jacobian log_jacobian=log_jacobian, ) - def loo_score(self, y_obs, log_weights, kind="crps", sample_dims=None, **kwargs): + def loo_score( + self, + y_obs, + log_ratios=None, + kind="crps", + r_eff=1.0, + log_weights=None, + pareto_k=None, + sample_dims=None, + **kwargs, + ): """Compute CRPS or SCRPS with PSIS-LOO-CV weights.""" return self._apply( "loo_score", y_obs=y_obs, + log_ratios=log_ratios, + kind=kind, + r_eff=r_eff, + log_weights=log_weights, + pareto_k=pareto_k, + sample_dims=sample_dims, + **kwargs, + ) + + def loo_pit( + self, + y_obs, + log_ratios=None, + r_eff=1.0, + log_weights=None, + pareto_k=None, + sample_dims=None, + random_state=None, + **kwargs, + ): + """Compute LOO-PIT values with PSIS-LOO-CV weights.""" + return self._apply( + "loo_pit", + y_obs=y_obs, + log_ratios=log_ratios, + r_eff=r_eff, log_weights=log_weights, + pareto_k=pareto_k, + sample_dims=sample_dims, + random_state=random_state, + **kwargs, + ) + + def loo_expectation( + self, + log_ratios=None, + kind="mean", + r_eff=1.0, + log_weights=None, + pareto_k=None, + sample_dims=None, + **kwargs, + ): + """Compute weighted expectation with PSIS-LOO-CV weights.""" + return self._apply( + "loo_expectation", + log_ratios=log_ratios, kind=kind, + r_eff=r_eff, + log_weights=log_weights, + pareto_k=pareto_k, + sample_dims=sample_dims, + **kwargs, + ) + + def loo_quantile( + self, + log_ratios=None, + probs=None, + r_eff=1.0, + log_weights=None, + pareto_k=None, + sample_dims=None, + **kwargs, + ): + """Compute weighted quantile with PSIS-LOO-CV weights.""" + return self._apply( + "loo_quantile", + log_ratios=log_ratios, + probs=probs, + r_eff=r_eff, + log_weights=log_weights, + pareto_k=pareto_k, sample_dims=sample_dims, **kwargs, ) @@ -199,6 +280,24 @@ def loo_summary(self, p_loo_i): """Aggregate pointwise LOO values.""" return self._apply("loo_summary", p_loo_i=p_loo_i) + def loo_r2( + self, + ypred_loo, + n_simulations=4000, + circular=False, + random_state=42, + **kwargs, + ): + """Compute LOO-adjusted :math:`R^2` using Dirichlet-weighted bootstrap.""" + return self._apply( + "loo_r2", + ypred_loo=ypred_loo, + n_simulations=n_simulations, + circular=circular, + random_state=random_state, + **kwargs, + ) + def power_scale_lw(self, dim=None, **kwargs): """Compute log weights for power-scaling of the DataTree.""" return self._apply("power_scale_lw", dim=dim, **kwargs) diff --git a/src/arviz_stats/base/array.py b/src/arviz_stats/base/array.py index 9d6aa75..de289d0 100644 --- a/src/arviz_stats/base/array.py +++ b/src/arviz_stats/base/array.py @@ -856,7 +856,7 @@ def loo( ary, chain_axis=-2, draw_axis=-1, - reff=1.0, + r_eff=1.0, log_weights=None, pareto_k=None, log_jacobian=None, @@ -866,19 +866,28 @@ def loo( Parameters ---------- ary : array-like + Log-likelihood values. chain_axis : int, default -2 + Axis for chains. draw_axis : int, default -1 - reff : float, default 1.0 + Axis for draws. + r_eff : float, default 1.0 + Relative effective sample size. log_weights : array-like, optional + Pre-computed PSIS log weights. pareto_k : array-like, optional + Pre-computed Pareto k-hat diagnostic values. log_jacobian : array-like, optional - Log-Jacobian adjustment for variable transformations + Log-Jacobian adjustment for variable transformations. Returns ------- elpd_i : array-like + Pointwise expected log predictive density. pareto_k : array-like + Pareto k-hat diagnostic values. p_loo_i : array-like + Pointwise effective number of parameters. """ ary, chain_axis, draw_axis = process_chain_none(ary, chain_axis, draw_axis) ary, axes = process_ary_axes(ary, [chain_axis, draw_axis]) @@ -890,7 +899,7 @@ def loo( loo_ufunc = make_ufunc(self._loo, n_output=3, n_input=1, n_dims=len(axes)) return loo_ufunc( - ary, r_eff=reff, log_weights=log_weights, pareto_k=pareto_k, log_jacobian=log_jacobian + ary, r_eff=r_eff, log_weights=log_weights, pareto_k=pareto_k, log_jacobian=log_jacobian ) def loo_approximate_posterior( @@ -907,18 +916,26 @@ def loo_approximate_posterior( Parameters ---------- ary : array-like + Log-likelihood values. log_p : array-like + Target log-density values. log_q : array-like + Proposal log-density values. chain_axis : int, default -2 + Axis for chains. draw_axis : int, default -1 + Axis for draws. log_jacobian : float, optional - Log-Jacobian adjustment for variable transformations + Log-Jacobian adjustment for variable transformations. Returns ------- elpd_i : array-like + Pointwise expected log predictive density. pareto_k : array-like + Pareto k-hat diagnostic values. p_loo_i : array-like + Pointwise effective number of parameters. """ ary, log_p, log_q, chain_axis, draw_axis = process_chain_none_multi( ary, log_p, log_q, chain_axis=chain_axis, draw_axis=draw_axis @@ -946,22 +963,24 @@ def loo_mixture( Parameters ---------- ary : array-like - Full log-likelihood array + Full log-likelihood array. obs_axes : tuple of int - Axes corresponding to observation dimensions + Axes corresponding to observation dimensions. chain_axis : int, default -2 + Axis for chains. draw_axis : int, default -1 + Axis for draws. log_jacobian : array-like, optional - Log-Jacobian adjustment for variable transformations + Log-Jacobian adjustment for variable transformations. Returns ------- elpd_i : array-like - Pointwise expected log predictive density + Pointwise expected log predictive density. p_loo_i : array-like - Pointwise effective number of parameters + Pointwise effective number of parameters. mix_log_weights : array-like - Mixture log weights + Mixture log weights. """ ary, chain_axis, draw_axis = process_chain_none(ary, chain_axis, draw_axis) ndim = ary.ndim @@ -970,7 +989,7 @@ def loo_mixture( obs_axes = tuple(ax % ndim for ax in obs_axes) sample_axes = (chain_axis, draw_axis) - return self._loo_mixture( + return self._mixture( ary, obs_axes=obs_axes, sample_axes=sample_axes, log_jacobian=log_jacobian ) @@ -988,55 +1007,208 @@ def loo_score( Parameters ---------- ary : array-like - Posterior predictive samples + Posterior predictive samples. y_obs : array-like - Observed values + Observed values. log_weights : array-like - PSIS-LOO log weights + Pre-computed PSIS log weights. kind : str, default "crps" - "crps" or "scrps" + Score type, either "crps" or "scrps". chain_axis : int, default -2 - Axis for chains + Axis for chains. draw_axis : int, default -1 - Axis for draws + Axis for draws. Returns ------- scores : array-like - Score values (negative CRPS or SCRPS for maximization) + Score values (negative orientation for maximization). """ ary, log_weights, chain_axis, draw_axis = process_chain_none_multi( ary, log_weights, chain_axis=chain_axis, draw_axis=draw_axis ) - ary, axes = process_ary_axes(ary, [chain_axis, draw_axis]) log_weights, _ = process_ary_axes(log_weights, [chain_axis, draw_axis]) loo_score_ufunc = make_ufunc(self._loo_score, n_output=1, n_input=3, n_dims=len(axes)) return loo_score_ufunc(ary, y_obs, log_weights, kind) + def loo_pit( + self, + ary, + y_obs, + log_weights, + chain_axis=-2, + draw_axis=-1, + random_state=None, + ): + """Compute LOO-PIT values with PSIS-LOO-CV weights. + + Parameters + ---------- + ary : array-like + Posterior predictive samples. + y_obs : array-like + Observed values. + log_weights : array-like + Pre-computed PSIS log weights. + chain_axis : int, default -2 + Axis for chains. + draw_axis : int, default -1 + Axis for draws. + random_state : int or Generator, optional + Random seed or Generator for tie-breaking. If None, uses seed 214. + + Returns + ------- + pit_values : array-like + LOO-PIT values in [0, 1]. + """ + ary, log_weights, chain_axis, draw_axis = process_chain_none_multi( + ary, log_weights, chain_axis=chain_axis, draw_axis=draw_axis + ) + ary, axes = process_ary_axes(ary, [chain_axis, draw_axis]) + log_weights, _ = process_ary_axes(log_weights, [chain_axis, draw_axis]) + + if random_state is None: + rng = np.random.default_rng(214) + else: + rng = np.random.default_rng(random_state) + + loo_pit_ufunc = make_ufunc(self._loo_pit, n_output=1, n_input=3, n_dims=len(axes)) + return loo_pit_ufunc(ary, y_obs, log_weights, rng=rng) + + def loo_expectation( + self, + ary, + log_weights, + kind="mean", + chain_axis=-2, + draw_axis=-1, + ): + """Compute weighted expectation with PSIS-LOO-CV weights. + + Parameters + ---------- + ary : array-like + Posterior predictive samples. + log_weights : array-like + Pre-computed PSIS log weights. + kind : str, default "mean" + Type of expectation: "mean", "median", "var", "sd", + "circular_mean", "circular_var", "circular_sd". + chain_axis : int, default -2 + Axis for chains. + draw_axis : int, default -1 + Axis for draws. + + Returns + ------- + expectation : array-like + Weighted expectation values. + """ + ary, log_weights, chain_axis, draw_axis = process_chain_none_multi( + ary, log_weights, chain_axis=chain_axis, draw_axis=draw_axis + ) + ary, axes = process_ary_axes(ary, [chain_axis, draw_axis]) + log_weights, _ = process_ary_axes(log_weights, [chain_axis, draw_axis]) + + loo_expectation_ufunc = make_ufunc( + self._loo_expectation, n_output=1, n_input=2, n_dims=len(axes) + ) + return loo_expectation_ufunc(ary, log_weights, kind) + + def loo_quantile( + self, + ary, + log_weights, + prob, + chain_axis=-2, + draw_axis=-1, + ): + """Compute weighted quantile with PSIS-LOO-CV weights. + + Parameters + ---------- + ary : array-like + Posterior predictive samples. + log_weights : array-like + Pre-computed PSIS log weights. + prob : float + Quantile probability in [0, 1]. + chain_axis : int, default -2 + Axis for chains. + draw_axis : int, default -1 + Axis for draws. + + Returns + ------- + quantile : array-like + Weighted quantile values. + """ + ary, log_weights, chain_axis, draw_axis = process_chain_none_multi( + ary, log_weights, chain_axis=chain_axis, draw_axis=draw_axis + ) + ary, axes = process_ary_axes(ary, [chain_axis, draw_axis]) + log_weights, _ = process_ary_axes(log_weights, [chain_axis, draw_axis]) + + loo_quantile_ufunc = make_ufunc(self._loo_quantile, n_output=1, n_input=2, n_dims=len(axes)) + return loo_quantile_ufunc(ary, log_weights, prob) + def loo_summary(self, elpd_i, p_loo_i): """Aggregate pointwise LOO values. Parameters ---------- elpd_i : array-like - Pointwise expected log predictive density + Pointwise expected log predictive density. p_loo_i : array-like - Pointwise effective number of parameters + Pointwise effective number of parameters. Returns ------- elpd : float - Total expected log predictive density + Total expected log predictive density. elpd_se : float - Standard error of elpd + Standard error of elpd. p_loo : float - Total effective number of parameters + Total effective number of parameters. lppd : float - Log pointwise predictive density + Log pointwise predictive density. """ - return self._loo_summary(elpd_i, p_loo_i) + return self._summary(elpd_i, p_loo_i) + + def loo_r2(self, y_obs, ypred_loo, n_simulations=4000, circular=False, random_state=42): + """Compute LOO-adjusted R-squared using Dirichlet-weighted bootstrap. + + Parameters + ---------- + y_obs : array-like + Observed values. + ypred_loo : array-like + LOO predictions (same shape as y_obs). + n_simulations : int, default 4000 + Number of Dirichlet-weighted bootstrap samples. + circular : bool, default False + Whether the variable is circular (angles in radians). + random_state : int, default 42 + Random seed for reproducibility. + + Returns + ------- + loo_r_squared : array-like + R-squared samples with shape (n_simulations,). + """ + y_obs = np.asarray(y_obs).ravel() + ypred_loo = np.asarray(ypred_loo).ravel() + + return self._loo_r2( + y_obs, + ypred_loo, + n_simulations=n_simulations, + circular=circular, + random_state=random_state, + ) array_stats = BaseArray() diff --git a/src/arviz_stats/base/dataarray.py b/src/arviz_stats/base/dataarray.py index 08423d6..d95c68b 100644 --- a/src/arviz_stats/base/dataarray.py +++ b/src/arviz_stats/base/dataarray.py @@ -421,35 +421,35 @@ def pareto_khat(self, da, sample_dims=None, r_eff=None, tail="both", log_weights ).rename("pareto_k") def loo( - self, da, sample_dims=None, reff=1.0, log_weights=None, pareto_k=None, log_jacobian=None + self, da, sample_dims=None, r_eff=1.0, log_weights=None, pareto_k=None, log_jacobian=None ): """Compute PSIS-LOO-CV. Parameters ---------- da : DataArray - Log-likelihood values with shape (chain, draw, *obs_dims) + Log-likelihood values. sample_dims : list of str, optional - Sample dimensions. Defaults to ["chain", "draw"] - reff : float, default 1.0 - Relative effective sample size + Sample dimensions. Defaults to ["chain", "draw"]. + r_eff : float, default 1.0 + Relative effective sample size. log_weights : DataArray, optional - Pre-computed PSIS log weights (same shape as da) + Pre-computed PSIS log weights. pareto_k : DataArray, optional - Pre-computed Pareto k values (shape: obs_dims only) + Pre-computed Pareto k-hat diagnostic values. log_jacobian : DataArray, optional - Log-Jacobian adjustment (shape: obs_dims only) + Log-Jacobian adjustment for variable transformations. Returns ------- tuple of (elpd_i, pareto_k, p_loo_i) : DataArrays - Pointwise LOO values for each observation + Pointwise LOO values for each observation. """ dims, chain_axis, draw_axis = validate_dims_chain_draw_axis(sample_dims) kwargs = { "chain_axis": chain_axis, "draw_axis": draw_axis, - "reff": reff, + "r_eff": r_eff, } if log_weights is not None and pareto_k is not None: @@ -503,20 +503,20 @@ def loo_approximate_posterior(self, da, log_p, log_q, sample_dims=None, log_jaco Parameters ---------- da : DataArray - Log-likelihood values + Log-likelihood values. log_p : DataArray - Target log-density values (chain, draw) + Target log-density values. log_q : DataArray - Proposal log-density values (chain, draw) + Proposal log-density values. sample_dims : list of str, optional - Sample dimensions. Defaults to ["chain", "draw"] + Sample dimensions. Defaults to ["chain", "draw"]. log_jacobian : DataArray, optional - Log-Jacobian adjustment for variable transformations + Log-Jacobian adjustment for variable transformations. Returns ------- tuple of (elpd_i, pareto_k, p_loo_i) : DataArrays - Pointwise LOO values for each observation + Pointwise LOO values for each observation. """ dims, chain_axis, draw_axis = validate_dims_chain_draw_axis(sample_dims) kwargs = { @@ -561,20 +561,20 @@ def loo_mixture(self, da, sample_dims=None, log_jacobian=None): Parameters ---------- da : DataArray - Log-likelihood values + Log-likelihood values. sample_dims : list of str, optional - Sample dimensions. Defaults to ["chain", "draw"] + Sample dimensions. Defaults to ["chain", "draw"]. log_jacobian : DataArray, optional - Log-Jacobian adjustment for variable transformations + Log-Jacobian adjustment for variable transformations. Returns ------- elpd_i : DataArray - Pointwise expected log predictive density + Pointwise expected log predictive density. p_loo_i : DataArray - Pointwise effective number of parameters + Pointwise effective number of parameters. mix_log_weights : DataArray - Mixture log weights + Mixture log weights. """ dims, _, _ = validate_dims_chain_draw_axis(sample_dims) obs_dims = [d for d in da.dims if d not in dims] @@ -599,29 +599,53 @@ def loo_mixture(self, da, sample_dims=None, log_jacobian=None): return elpd_i, p_loo_i, mix_log_weights - def loo_score(self, da, y_obs, log_weights, kind="crps", sample_dims=None): + def loo_score( + self, + da, + y_obs, + log_ratios=None, + kind="crps", + r_eff=1.0, + log_weights=None, + pareto_k=None, + sample_dims=None, + ): """Compute CRPS or SCRPS with PSIS-LOO-CV weights. Parameters ---------- da : DataArray - Posterior predictive samples + Posterior predictive samples. y_obs : DataArray or scalar - Observed values - log_weights : DataArray - PSIS-LOO log weights + Observed values. + log_ratios : DataArray, optional + Log importance ratios (typically -log_likelihood). If provided, + PSIS will be computed internally. kind : str, default "crps" - "crps" or "scrps" + Score type, either "crps" or "scrps". + r_eff : float, default 1.0 + Relative effective sample size. + log_weights : DataArray, optional + Pre-computed PSIS log weights. + pareto_k : DataArray, optional + Pre-computed Pareto k-hat diagnostic values. sample_dims : list of str, optional - Sample dimensions. Defaults to ["chain", "draw"] + Sample dimensions. Defaults to ["chain", "draw"]. Returns ------- scores : DataArray - Score values (negative CRPS or SCRPS for maximization) + Score values (negative orientation for maximization). + pareto_k : DataArray + Pareto k-hat diagnostic values. """ dims, chain_axis, draw_axis = validate_dims_chain_draw_axis(sample_dims) - return apply_ufunc( + if log_weights is None: + if log_ratios is None: + raise ValueError("Either log_ratios or log_weights must be provided") + log_weights, pareto_k = self.psislw(-log_ratios, r_eff=r_eff, dim=list(dims)) + + scores = apply_ufunc( self.array_class.loo_score, da, y_obs, @@ -634,6 +658,192 @@ def loo_score(self, da, y_obs, log_weights, kind="crps", sample_dims=None): "draw_axis": draw_axis, }, ) + return scores, pareto_k + + def loo_pit( + self, + da, + y_obs, + log_ratios=None, + r_eff=1.0, + log_weights=None, + pareto_k=None, + sample_dims=None, + random_state=None, + ): + """Compute LOO-PIT values on DataArray input. + + Parameters + ---------- + da : DataArray + Posterior predictive samples. + y_obs : DataArray + Observed values. + log_ratios : DataArray, optional + Log importance ratios (typically -log_likelihood). If provided, + PSIS will be computed internally. + r_eff : float, default 1.0 + Relative effective sample size. + log_weights : DataArray, optional + Pre-computed PSIS log weights. + pareto_k : DataArray, optional + Pre-computed Pareto k-hat diagnostic values. + sample_dims : list of str, optional + Sample dimensions. Defaults to ["chain", "draw"]. + random_state : int or Generator, optional + Random seed or Generator for tie-breaking. If None, uses seed 214. + + Returns + ------- + pit_values : DataArray + LOO-PIT values in [0, 1]. + pareto_k : DataArray + Pareto k-hat diagnostic values. + """ + dims, chain_axis, draw_axis = validate_dims_chain_draw_axis(sample_dims) + if log_weights is None: + if log_ratios is None: + raise ValueError("Either log_ratios or log_weights must be provided") + log_weights, pareto_k = self.psislw(-log_ratios, r_eff=r_eff, dim=list(dims)) + + pit_values = apply_ufunc( + self.array_class.loo_pit, + da, + y_obs, + log_weights, + input_core_dims=[dims, [], dims], + output_core_dims=[[]], + kwargs={ + "chain_axis": chain_axis, + "draw_axis": draw_axis, + "random_state": random_state, + }, + ) + return pit_values, pareto_k + + def loo_expectation( + self, + da, + log_ratios=None, + kind="mean", + r_eff=1.0, + log_weights=None, + pareto_k=None, + sample_dims=None, + ): + """Compute weighted expectation on DataArray input. + + Parameters + ---------- + da : DataArray + Posterior predictive samples. + log_ratios : DataArray, optional + Log importance ratios (typically -log_likelihood). If provided, + PSIS will be computed internally. + kind : str, default "mean" + Type of expectation: "mean", "median", "var", "sd", + "circular_mean", "circular_var", "circular_sd". + r_eff : float, default 1.0 + Relative effective sample size. + log_weights : DataArray, optional + Pre-computed PSIS log weights. + pareto_k : DataArray, optional + Pre-computed Pareto k-hat diagnostic values. + sample_dims : list of str, optional + Sample dimensions. Defaults to ["chain", "draw"]. + + Returns + ------- + expectation : DataArray + Weighted expectation values. + pareto_k : DataArray + Pareto k-hat diagnostic values. + """ + dims, chain_axis, draw_axis = validate_dims_chain_draw_axis(sample_dims) + if log_weights is None: + if log_ratios is None: + raise ValueError("Either log_ratios or log_weights must be provided") + log_weights, pareto_k = self.psislw(-log_ratios, r_eff=r_eff, dim=list(dims)) + + expectation = apply_ufunc( + self.array_class.loo_expectation, + da, + log_weights, + input_core_dims=[dims, dims], + output_core_dims=[[]], + kwargs={ + "kind": kind, + "chain_axis": chain_axis, + "draw_axis": draw_axis, + }, + ) + return expectation, pareto_k + + def loo_quantile( + self, + da, + log_ratios=None, + probs=None, + r_eff=1.0, + log_weights=None, + pareto_k=None, + sample_dims=None, + ): + """Compute weighted quantile on DataArray input. + + Parameters + ---------- + da : DataArray + Posterior predictive samples. + log_ratios : DataArray, optional + Log importance ratios (typically -log_likelihood). If provided, + PSIS will be computed internally. + probs : float or array-like + Quantile probability(ies) in [0, 1]. + r_eff : float, default 1.0 + Relative effective sample size. + log_weights : DataArray, optional + Pre-computed PSIS log weights. + pareto_k : DataArray, optional + Pre-computed Pareto k-hat diagnostic values. + sample_dims : list of str, optional + Sample dimensions. Defaults to ["chain", "draw"]. + + Returns + ------- + quantile : DataArray + Weighted quantile values. + pareto_k : DataArray + Pareto k-hat diagnostic values. + """ + dims, chain_axis, draw_axis = validate_dims_chain_draw_axis(sample_dims) + if log_weights is None: + if log_ratios is None: + raise ValueError("Either log_ratios or log_weights must be provided") + log_weights, pareto_k = self.psislw(-log_ratios, r_eff=r_eff, dim=list(dims)) + + probs = np.atleast_1d(probs) + quantile_results = [] + + for prob in probs: + quantile = apply_ufunc( + self.array_class.loo_quantile, + da, + log_weights, + input_core_dims=[dims, dims], + output_core_dims=[[]], + kwargs={ + "prob": float(prob), + "chain_axis": chain_axis, + "draw_axis": draw_axis, + }, + ) + quantile_results.append(quantile) + + if len(probs) == 1: + return quantile_results[0], pareto_k + + return concat(quantile_results, dim="quantile").assign_coords(quantile=probs), pareto_k def loo_summary(self, da, p_loo_i): """Aggregate pointwise LOO values. @@ -641,23 +851,62 @@ def loo_summary(self, da, p_loo_i): Parameters ---------- da : DataArray - Pointwise expected log predictive density values (elpd_i) + Pointwise expected log predictive density (elpd_i). p_loo_i : DataArray - Pointwise effective number of parameters + Pointwise effective number of parameters. Returns ------- elpd : float - Total expected log predictive density + Total expected log predictive density. se : float - Standard error of elpd + Standard error of elpd. p_loo : float - Total effective number of parameters + Total effective number of parameters. lppd : float - Log pointwise predictive density + Log pointwise predictive density. """ return self.array_class.loo_summary(da.values, p_loo_i.values) + def loo_r2( + self, + da, + ypred_loo, + n_simulations=4000, + circular=False, + random_state=42, + ): + """Compute LOO-adjusted :math:`R^2` using Dirichlet-weighted bootstrap. + + Parameters + ---------- + da : DataArray + Observed values (passed via accessor as self._obj). + ypred_loo : DataArray + LOO predictions (same shape as da). + n_simulations : int, default 4000 + Number of Dirichlet-weighted bootstrap samples. + circular : bool, default False + Whether the variable is circular (angles in radians). + random_state : int, default 42 + Random seed for reproducibility. + + Returns + ------- + loo_r_squared : ndarray + Array of :math:`R^2` samples with shape (n_simulations,). + """ + y_obs_vals = np.asarray(da) + ypred_loo_vals = np.asarray(ypred_loo) + + return self.array_class.loo_r2( + y_obs_vals, + ypred_loo_vals, + n_simulations=n_simulations, + circular=circular, + random_state=random_state, + ) + def power_scale_lw(self, da, alpha=0, dim=None): """Compute log weights for power-scaling component by alpha.""" dims = validate_dims(dim) diff --git a/src/arviz_stats/base/diagnostics.py b/src/arviz_stats/base/diagnostics.py index f9ea017..e2db77f 100644 --- a/src/arviz_stats/base/diagnostics.py +++ b/src/arviz_stats/base/diagnostics.py @@ -8,7 +8,7 @@ from scipy import stats from scipy.special import logsumexp -from arviz_stats.base.circular_utils import circular_diff, circular_var +from arviz_stats.base.circular_utils import circular_diff, circular_mean, circular_sd, circular_var from arviz_stats.base.core import _CoreBase from arviz_stats.base.stats_utils import not_valid as _not_valid @@ -461,7 +461,219 @@ def _loo_approximate_posterior(self, ary, log_p, log_q, log_jacobian=None): ) @staticmethod - def _loo_mixture(ary, obs_axes, sample_axes, log_jacobian=None): + def _loo_score(ary, y_obs, log_weights, kind): + """ + Compute CRPS or SCRPS for a single observation. + + Parameters + ---------- + ary : np.ndarray + 1D array of posterior predictive samples (flattened chain*draw). + y_obs : float + Observed value. + log_weights : np.ndarray + 1D array of pre-computed PSIS log weights. + kind : str + Score type, either "crps" or "scrps". + + Returns + ------- + score : float + Score value (negative orientation for maximization). + """ + ary = np.asarray(ary).ravel() + log_weights = np.asarray(log_weights).ravel() + y_obs = np.asarray(y_obs).flat[0] + abs_error = np.abs(ary - y_obs) + + log_den = logsumexp(log_weights) + loo_weighted_abs_error = np.exp(logsumexp(log_weights, b=abs_error) - log_den) + loo_weighted_mean_prediction = np.exp(logsumexp(log_weights, b=ary) - log_den) + + weights = np.exp(log_weights - log_weights.max()) + weights /= np.sum(weights) + + idx = np.argsort(ary, kind="mergesort") + values_sorted = ary[idx] + weights_sorted = weights[idx] + + cumulative_weights = np.cumsum(weights_sorted) + f_minus = cumulative_weights - weights_sorted + f_mid = f_minus + weights_sorted / 2 + pwm_first_moment_b1 = np.sum(weights_sorted * values_sorted * f_mid) + + crps = loo_weighted_abs_error + loo_weighted_mean_prediction - 2.0 * pwm_first_moment_b1 + + if kind == "crps": + return -crps + + bracket = 2.0 * f_minus + weights_sorted - 1.0 + gini_mean_difference = 2.0 * np.sum(weights_sorted * values_sorted * bracket) + return -(loo_weighted_abs_error / gini_mean_difference) - 0.5 * np.log(gini_mean_difference) + + @staticmethod + def _loo_pit(ary, y_obs, log_weights, rng=None): + """ + Compute LOO-PIT value for a single observation. + + Parameters + ---------- + ary : np.ndarray + 1D array of posterior predictive samples (flattened chain*draw). + y_obs : float + Observed value. + log_weights : np.ndarray + 1D array of pre-computed PSIS log weights. + rng : np.random.Generator, optional + Random number generator for tie-breaking. If None, uses midpoint. + + Returns + ------- + pit : float + LOO-PIT value in [0, 1]. + """ + ary = np.asarray(ary).ravel() + log_weights = np.asarray(log_weights).ravel() + y_obs_val = np.asarray(y_obs).ravel()[0] + log_norm = logsumexp(log_weights) + weights = np.exp(log_weights - log_norm) + + sel_below = ary < y_obs_val + pit_lower = np.sum(weights[sel_below]) + + sel_equal = ary == y_obs_val + if np.any(sel_equal): + pit_at_obs = np.sum(weights[sel_equal]) + pit_upper = pit_lower + pit_at_obs + + if rng is not None: + return rng.uniform(pit_lower, pit_upper) + return (pit_lower + pit_upper) / 2.0 + + return pit_lower + + def _loo_expectation(self, ary, log_weights, kind): + """ + Compute weighted expectation for a single observation. + + Parameters + ---------- + ary : np.ndarray + 1D array of posterior predictive samples (flattened chain*draw). + log_weights : np.ndarray + 1D array of pre-computed PSIS log weights. + kind : str + Type of expectation: "mean", "median", "var", "sd", + "circular_mean", "circular_var", "circular_sd". + + Returns + ------- + expectation : float + Weighted expectation value. + """ + ary = np.asarray(ary).ravel() + log_weights = np.asarray(log_weights).ravel() + log_norm = logsumexp(log_weights) + weights = np.exp(log_weights - log_norm) + + if kind == "mean": + result = np.sum(weights * ary) + elif kind == "median": + result = self._weighted_quantile(ary, weights, 0.5) + elif kind in ("var", "sd"): + mean_val = np.sum(weights * ary) + ess = 1.0 / np.sum(weights**2) + correction = ess / (ess - 1) if ess > 1 else 1.0 + var_val = np.sum(weights * (ary - mean_val) ** 2) * correction + result = np.sqrt(var_val) if kind == "sd" else var_val + elif kind in ("circular_mean", "circular_var", "circular_sd"): + angles_2d = ary.reshape(1, -1) + weights_2d = weights.reshape(1, -1) + if kind == "circular_mean": + result = circular_mean(angles_2d, weights_2d)[0] + elif kind == "circular_var": + result = circular_var(angles_2d, weights_2d)[0] + else: # circular_sd + result = circular_sd(angles_2d, weights_2d)[0] + else: + raise ValueError(f"Unknown kind: {kind}") + + return result + + def _loo_quantile(self, ary, log_weights, prob): + """ + Compute weighted quantile for a single observation. + + Parameters + ---------- + ary : np.ndarray + 1D array of posterior predictive samples (flattened chain*draw). + log_weights : np.ndarray + 1D array of pre-computed PSIS log weights. + prob : float + Quantile probability in [0, 1]. + + Returns + ------- + quantile : float + Weighted quantile value. + """ + ary = np.asarray(ary).ravel() + log_weights = np.asarray(log_weights).ravel() + log_norm = logsumexp(log_weights) + weights = np.exp(log_weights - log_norm) + + return self._weighted_quantile(ary, weights, prob) + + @staticmethod + def _loo_r2(y_obs, ypred_loo, n_simulations=4000, circular=False, random_state=42): + """ + Compute LOO-adjusted :math:`R^2` using Dirichlet-weighted bootstrap. + + Parameters + ---------- + y_obs : np.ndarray + 1D array of observed values. + ypred_loo : np.ndarray + 1D array of LOO predictions (same length as y_obs). + n_simulations : int, default 4000 + Number of Dirichlet-weighted bootstrap samples. + circular : bool, default False + Whether the variable is circular (angles in radians). + random_state : int, default 42 + Random seed for reproducibility. + + Returns + ------- + loo_r_squared : np.ndarray + Array of :math:`R^2` samples with shape (n_simulations,). + """ + y_obs = np.asarray(y_obs).ravel() + ypred_loo = np.asarray(ypred_loo).ravel() + + if circular: + eloo = circular_diff(ypred_loo, y_obs) + else: + eloo = ypred_loo - y_obs + + n = len(y_obs) + rd = stats.dirichlet.rvs(np.ones(n), size=n_simulations, random_state=random_state) + + if circular: + loo_r_squared = 1 - circular_var(eloo, rd) + else: + vary = (np.sum(rd * y_obs**2, axis=1) - np.sum(rd * y_obs, axis=1) ** 2) * (n / (n - 1)) + vareloo = (np.sum(rd * eloo**2, axis=1) - np.sum(rd * eloo, axis=1) ** 2) * ( + n / (n - 1) + ) + + loo_r_squared = 1 - vareloo / vary + loo_r_squared = np.clip(loo_r_squared, -1, 1) + + return loo_r_squared + + @staticmethod + def _mixture(ary, obs_axes, sample_axes, log_jacobian=None): """ Compute mixture importance sampling LOO (Mix-IS-LOO). @@ -505,59 +717,50 @@ def _loo_mixture(ary, obs_axes, sample_axes, log_jacobian=None): return elpd_i, p_loo_i, mix_log_weights @staticmethod - def _loo_score(ary, y_obs, log_weights, kind): - """ - Compute CRPS or SCRPS for a single observation. + def _weighted_quantile(ary, weights, prob): + """Compute weighted quantile. Parameters ---------- ary : np.ndarray - 2D array (chain, draw) of posterior predictive samples - y_obs : float - Observed value - log_weights : np.ndarray - 2D array (chain, draw) of PSIS-LOO-CV log weights - kind : str - "crps" or "scrps" + 1D array of values + weights : np.ndarray + 1D array of normalized weights (must sum to 1) + prob : float + Quantile probability in [0, 1] Returns ------- - score : float - The score value (negative CRPS or SCRPS for maximization) + quantile : float + Weighted quantile value """ - ary = np.asarray(ary).ravel() - log_weights = np.asarray(log_weights).ravel() - y_obs = np.asarray(y_obs).flat[0] - - abs_error = np.abs(ary - y_obs) - - log_den = logsumexp(log_weights) - loo_weighted_abs_error = np.exp(logsumexp(log_weights, b=abs_error) - log_den) - loo_weighted_mean_prediction = np.exp(logsumexp(log_weights, b=ary) - log_den) + nonzero = weights != 0 + ary = ary[nonzero] + weights = weights[nonzero] + n = ary.size - weights = np.exp(log_weights - log_weights.max()) - weights /= np.sum(weights) + if n == 0: + return np.nan + nw = weights.sum() ** 2 / (weights**2).sum() - idx = np.argsort(ary, kind="mergesort") - values_sorted = ary[idx] - weights_sorted = weights[idx] + idx = np.argsort(ary) + sorted_ary = ary[idx] + sorted_weights = weights[idx] - cumulative_weights = np.cumsum(weights_sorted) - f_minus = cumulative_weights - weights_sorted - f_mid = f_minus + weights_sorted / 2 - pwm_first_moment_b1 = np.sum(weights_sorted * values_sorted * f_mid) + sorted_weights = sorted_weights / sorted_weights.sum() + weights_cum = np.concatenate([[0], np.cumsum(sorted_weights)]) - crps = loo_weighted_abs_error + loo_weighted_mean_prediction - 2.0 * pwm_first_moment_b1 + h = (nw - 1) * prob + 1 + h = np.clip(h, 1, nw) - if kind == "crps": - return -crps + u = np.maximum((h - 1) / nw, np.minimum(h / nw, weights_cum)) + v = u * nw - h + 1 + w = np.diff(v) - bracket = 2.0 * f_minus + weights_sorted - 1.0 - gini_mean_difference = 2.0 * np.sum(weights_sorted * values_sorted * bracket) - return -(loo_weighted_abs_error / gini_mean_difference) - 0.5 * np.log(gini_mean_difference) + return np.sum(sorted_ary * w) @staticmethod - def _loo_summary(elpd_i, p_loo_i): + def _summary(elpd_i, p_loo_i): """ Aggregate pointwise LOO values. diff --git a/src/arviz_stats/loo/helper_loo.py b/src/arviz_stats/loo/helper_loo.py index 56ee770..5024416 100644 --- a/src/arviz_stats/loo/helper_loo.py +++ b/src/arviz_stats/loo/helper_loo.py @@ -125,7 +125,7 @@ def _compute_loo_results( elpd_i, pareto_k, p_loo_i = log_likelihood_da.azstats.loo( sample_dims=sample_dims, - reff=reff, + r_eff=reff, log_weights=log_weights, pareto_k=pareto_k, log_jacobian=log_jacobian, diff --git a/src/arviz_stats/loo/loo_expectations.py b/src/arviz_stats/loo/loo_expectations.py index 8dd7220..7809cc7 100644 --- a/src/arviz_stats/loo/loo_expectations.py +++ b/src/arviz_stats/loo/loo_expectations.py @@ -3,21 +3,20 @@ import numpy as np import xarray as xr from arviz_base import convert_to_datatree, extract, rcParams -from scipy.stats import dirichlet from xarray import apply_ufunc -from arviz_stats.base.circular_utils import circular_diff, circular_mean, circular_sd, circular_var -from arviz_stats.loo.helper_loo import _warn_pareto_k +from arviz_stats.loo.helper_loo import _get_r_eff, _warn_pareto_k from arviz_stats.metrics import _metrics, _summary_r2 -from arviz_stats.utils import ELPDData, get_log_likelihood_dataset +from arviz_stats.utils import get_log_likelihood_dataset def loo_expectations( data, var_name=None, - log_weights=None, kind="mean", probs=None, + log_weights=None, + pareto_k=None, ): """Compute weighted expectations using the PSIS-LOO-CV method. @@ -31,13 +30,6 @@ def loo_expectations( var_name: str, optional The name of the variable in log_likelihood groups storing the pointwise log likelihood data to use for loo computation. - log_weights : DataArray or ELPDData, optional - Smoothed log weights. Can be either: - - - A DataArray with the same shape as the log likelihood data - - An ELPDData object from a previous :func:`arviz_stats.loo` call. - - Defaults to None. If not provided, it will be computed using the PSIS-LOO method. kind: str, optional The kind of expectation to compute. Available options are: @@ -51,6 +43,11 @@ def loo_expectations( - 'circular_sd'. probs: float or list of float, optional The quantile(s) to compute when kind is 'quantile'. + log_weights : DataArray, optional + Pre-computed smoothed log weights from PSIS. Must be provided together with pareto_k. + If not provided, PSIS will be computed internally. + pareto_k : DataArray, optional + Pre-computed Pareto k-hat diagnostic values. Must be provided together with log_weights. Returns ------- @@ -112,61 +109,54 @@ def loo_expectations( log_likelihood = get_log_likelihood_dataset(data, var_names=var_name) n_samples = log_likelihood[var_name].sizes["chain"] * log_likelihood[var_name].sizes["draw"] - if log_weights is None: - log_weights, _ = log_likelihood.azstats.psislw() - log_weights = log_weights[var_name] - - if isinstance(log_weights, ELPDData): - if log_weights.log_weights is None: - raise ValueError("ELPDData object does not contain log_weights") - log_weights = log_weights.log_weights - if var_name in log_weights: - log_weights = log_weights[var_name] - - weights = np.exp(log_weights) + r_eff = _get_r_eff(data, n_samples) posterior_predictive = extract( data, group="posterior_predictive", var_names=var_name, combined=False ) - weighted_predictions = posterior_predictive.weighted(weights) - - if kind == "mean": - loo_expec = weighted_predictions.mean(dim=dims) - - elif kind == "median": - loo_expec = weighted_predictions.quantile(0.5, dim=dims) - - elif kind == "var": - # We use a Bessel's like correction term - # instead of n/(n-1) we use ESS/(ESS-1) - # where ESS/(ESS-1) = 1/(1-sum(weights**2)) - loo_expec = weighted_predictions.var(dim=dims) / (1 - np.sum(weights**2)) - - elif kind == "sd": - loo_expec = (weighted_predictions.var(dim=dims) / (1 - np.sum(weights**2))) ** 0.5 - - elif kind == "quantile": - loo_expec = weighted_predictions.quantile(probs, dim=dims) - - elif kind == "circular_mean": - weights = weights / weights.sum(dim=dims) - loo_expec = circular_mean(posterior_predictive, weights=weights, dims=dims) - - elif kind == "circular_var": - weights = weights / weights.sum(dim=dims) - loo_expec = circular_var(posterior_predictive, weights=weights, dims=dims) - else: # kind == "circular_sd" - weights = weights / weights.sum(dim=dims) - loo_expec = circular_sd(posterior_predictive, weights=weights, dims=dims) - - log_ratios = -log_likelihood[var_name] + sample_dims = list(dims) + + if log_weights is not None and pareto_k is not None: + if kind in ("mean", "median", "var", "sd", "circular_mean", "circular_var", "circular_sd"): + loo_expec, _ = posterior_predictive.azstats.loo_expectation( + log_weights=log_weights, + pareto_k=pareto_k, + kind=kind, + r_eff=r_eff, + sample_dims=sample_dims, + ) + else: + loo_expec, _ = posterior_predictive.azstats.loo_quantile( + log_weights=log_weights, + pareto_k=pareto_k, + probs=probs, + r_eff=r_eff, + sample_dims=sample_dims, + ) + log_ratios_for_khat = log_weights + else: + log_ratios = -log_likelihood[var_name] + if kind in ("mean", "median", "var", "sd", "circular_mean", "circular_var", "circular_sd"): + loo_expec, _ = posterior_predictive.azstats.loo_expectation( + log_ratios=log_ratios, + kind=kind, + r_eff=r_eff, + sample_dims=sample_dims, + ) + else: + loo_expec, _ = posterior_predictive.azstats.loo_quantile( + log_ratios=log_ratios, + probs=probs, + r_eff=r_eff, + sample_dims=sample_dims, + ) + log_ratios_for_khat = log_ratios - # Compute function-specific khat khat = apply_ufunc( _get_function_khat, posterior_predictive, - log_ratios, + log_ratios_for_khat, input_core_dims=[dims, dims], output_core_dims=[[]], exclude_dims=set(dims), @@ -181,7 +171,7 @@ def loo_expectations( return loo_expec, khat -def loo_metrics(data, kind="rmse", var_name=None, log_weights=None, round_to=None): +def loo_metrics(data, kind="rmse", var_name=None, round_to=None): """Compute predictive metrics using the PSIS-LOO-CV method. Currently supported metrics are mean absolute error, mean squared error and @@ -206,13 +196,6 @@ def loo_metrics(data, kind="rmse", var_name=None, log_weights=None, round_to=Non var_name: str, optional The name of the variable in log_likelihood groups storing the pointwise log likelihood data to use for loo computation. - log_weights: DataArray or ELPDData, optional - Smoothed log weights. Can be either: - - - A DataArray with the same shape as the log likelihood data - - An ELPDData object from a previous :func:`arviz_stats.loo` call. - - Defaults to None. If not provided, it will be computed using the PSIS-LOO method. round_to: int or str or None, optional If integer, number of decimal places to round the result. If string of the form '2g' number of significant digits to round the result. Defaults to '2g'. @@ -258,7 +241,7 @@ def loo_metrics(data, kind="rmse", var_name=None, log_weights=None, round_to=Non var_name = list(data.observed_data.data_vars.keys())[0] observed = data.observed_data[var_name] - predicted, _ = loo_expectations(data, kind="mean", var_name=var_name, log_weights=log_weights) + predicted, _ = loo_expectations(data, kind="mean", var_name=var_name) return _metrics(observed, predicted, kind, round_to) @@ -350,32 +333,21 @@ def loo_r2( if round_to is None: round_to = rcParams["stats.round_to"] - y = data.observed_data[var_name].values + y_obs = data.observed_data[var_name] if circular: kind = "circular_mean" else: kind = "mean" - # Here we should compute the loo-adjusted posterior mean, not the predictive mean - ypred_loo = loo_expectations(data, var_name=var_name, kind=kind)[0].values - - if circular: - eloo = circular_diff(ypred_loo, y) - else: - eloo = ypred_loo - y - - n = len(y) - rd = dirichlet.rvs(np.ones(n), size=n_simulations, random_state=42) - - if circular: - loo_r_squared = 1 - circular_var(eloo, rd) - else: - vary = (np.sum(rd * y**2, axis=1) - np.sum(rd * y, axis=1) ** 2) * (n / (n - 1)) - vareloo = (np.sum(rd * eloo**2, axis=1) - np.sum(rd * eloo, axis=1) ** 2) * (n / (n - 1)) + ypred_loo = loo_expectations(data, var_name=var_name, kind=kind)[0] - loo_r_squared = 1 - vareloo / vary - loo_r_squared = np.clip(loo_r_squared, -1, 1) + loo_r_squared = y_obs.azstats.loo_r2( + ypred_loo=ypred_loo, + n_simulations=n_simulations, + circular=circular, + random_state=42, + ) if summary: return _summary_r2("loo", loo_r_squared, point_estimate, ci_kind, ci_prob, round_to) diff --git a/src/arviz_stats/loo/loo_pit.py b/src/arviz_stats/loo/loo_pit.py index bbb1c7b..0ae542d 100644 --- a/src/arviz_stats/loo/loo_pit.py +++ b/src/arviz_stats/loo/loo_pit.py @@ -1,18 +1,18 @@ """Compute leave one out (PSIS-LOO) probability integral transform (PIT) values.""" -import numpy as np import xarray as xr from arviz_base import convert_to_datatree, extract -from xarray_einstats.stats import logsumexp from arviz_stats.loo.helper_loo import _get_r_eff -from arviz_stats.utils import ELPDData, get_log_likelihood_dataset +from arviz_stats.utils import get_log_likelihood_dataset def loo_pit( data, var_names=None, log_weights=None, + pareto_k=None, + random_state=None, ): r"""Compute leave one out (PSIS-LOO) probability integral transform (PIT) values. @@ -30,13 +30,15 @@ def loo_pit( Names of the variables to be used to compute the LOO-PIT values. If None, all variables are used. The function assumes that the observed and log_likelihood variables share the same names. - log_weights: DataArray or ELPDData, optional - Smoothed log weights. Can be either: - - - A DataArray with the same shape as ``y_pred`` - - An ELPDData object from a previous :func:`arviz_stats.loo` call. - - Defaults to None. If not provided, it will be computed using the PSIS-LOO method. + log_weights : Dataset, optional + Pre-computed smoothed log weights from PSIS. Must be a Dataset with variables + matching var_names. Must be provided together with pareto_k. + pareto_k : Dataset, optional + Pre-computed Pareto k-hat diagnostic values. Must be a Dataset with variables + matching var_names. Must be provided together with log_weights. + random_state : int or Generator, optional + Random seed or numpy Generator for tie-breaking randomization in discrete data. + If None, uses seed 214 for reproducibility. Returns ------- @@ -82,7 +84,6 @@ def loo_pit( arXiv preprint https://arxiv.org/abs/1507.02646 """ data = convert_to_datatree(data) - rng = np.random.default_rng(214) if var_names is None: var_names = list(data.observed_data.data_vars.keys()) @@ -90,16 +91,8 @@ def loo_pit( var_names = [var_names] log_likelihood = get_log_likelihood_dataset(data, var_names=var_names) - - if log_weights is None: - n_samples = log_likelihood.chain.size * log_likelihood.draw.size - reff = _get_r_eff(data, n_samples) - log_weights, _ = log_likelihood.azstats.psislw(r_eff=reff) - - if isinstance(log_weights, ELPDData): - if log_weights.log_weights is None: - raise ValueError("ELPDData object does not contain log_weights") - log_weights = log_weights.log_weights + n_samples = log_likelihood.chain.size * log_likelihood.draw.size + r_eff = _get_r_eff(data, n_samples) posterior_predictive = extract( data, @@ -116,32 +109,31 @@ def loo_pit( keep_dataset=True, ) - sel_min = {} - sel_sup = {} + sample_dims = ["chain", "draw"] + loo_pit_values = xr.Dataset(coords=observed_data.coords) + for var in var_names: pred = posterior_predictive[var] obs = observed_data[var] - sel_min[var] = pred < obs - sel_sup[var] = pred == obs - - sel_min = xr.Dataset(sel_min) - sel_sup = xr.Dataset(sel_sup) - pit = np.exp(logsumexp(log_weights.where(sel_min, -np.inf), dims=["chain", "draw"])) - - loo_pit_values = xr.Dataset(coords=observed_data.coords) - for var in var_names: - pit_lower = pit[var].values - - if sel_sup[var].any(): - pit_sup_addition = np.exp( - logsumexp(log_weights.where(sel_sup[var], -np.inf), dims=["chain", "draw"]) + if log_weights is not None and pareto_k is not None: + pit_values, _ = pred.azstats.loo_pit( + y_obs=obs, + log_weights=log_weights[var], + pareto_k=pareto_k[var], + r_eff=r_eff, + sample_dims=sample_dims, + random_state=random_state, ) - - pit_upper = pit_lower + pit_sup_addition[var].values - random_value = rng.uniform(pit_lower, pit_upper) - loo_pit_values[var] = observed_data[var].copy(data=random_value) else: - loo_pit_values[var] = observed_data[var].copy(data=pit_lower) + log_ratios = -log_likelihood[var] + pit_values, _ = pred.azstats.loo_pit( + y_obs=obs, + log_ratios=log_ratios, + r_eff=r_eff, + sample_dims=sample_dims, + random_state=random_state, + ) + loo_pit_values[var] = pit_values return loo_pit_values diff --git a/src/arviz_stats/loo/loo_score.py b/src/arviz_stats/loo/loo_score.py index b376610..9f1e13e 100644 --- a/src/arviz_stats/loo/loo_score.py +++ b/src/arviz_stats/loo/loo_score.py @@ -17,11 +17,11 @@ def loo_score( data, var_name=None, - log_weights=None, - pareto_k=None, kind="crps", pointwise=False, round_to=None, + log_weights=None, + pareto_k=None, ): r"""Compute PWM-based CRPS/SCRPS with PSIS-LOO-CV weights. @@ -55,14 +55,6 @@ def loo_score( The name of the variable in the log_likelihood group to use. If None, the first variable in ``observed_data`` is used and assumed to match ``log_likelihood`` and ``posterior_predictive`` names. - log_weights : DataArray, optional - Smoothed log weights for PSIS-LOO-CV. Must have the same shape as the log-likelihood data. - Defaults to None. If not provided, they will be computed via PSIS-LOO-CV. Must be provided - together with ``pareto_k`` or both must be None. - pareto_k : DataArray, optional - Pareto tail indices corresponding to the PSIS smoothing. Same shape as the log-likelihood - data. If not provided, they will be computed via PSIS-LOO-CV. Must be provided together with - ``log_weights`` or both must be None. kind : str, default "crps" The kind of score to compute. Available options are: @@ -74,13 +66,18 @@ def loo_score( If integer, number of decimal places to round the result. If string of the form '2g' number of significant digits to round the result. Defaults to '2g'. Use None to return raw numbers. + log_weights : DataArray, optional + Pre-computed smoothed log weights from PSIS. Must be provided together with pareto_k. + If not provided, PSIS will be computed internally. + pareto_k : DataArray, optional + Pre-computed Pareto k-hat diagnostic values. Must be provided together with log_weights. Returns ------- namedtuple If ``pointwise`` is False (default), a namedtuple named ``CRPS`` or ``SCRPS`` with fields - ``mean`` and ``se``. If ``pointwise`` is True, the namedtuple also includes a ``pointwise`` - field with per-observation values. + ``mean`` and ``se``. If ``pointwise`` is True, the namedtuple also includes ``pointwise`` + and ``pareto_k`` fields. Examples -------- @@ -99,18 +96,6 @@ def loo_score( In [2]: loo_score(dt, kind="scrps") - We can also pass previously computed PSIS-LOO weights and return the pointwise values: - - .. ipython:: - :okwarning: - - In [3]: from arviz_stats import loo - ...: loo_data = loo(dt, pointwise=True) - ...: loo_score(dt, kind="crps", - ...: log_weights=loo_data.log_weights, - ...: pareto_k=loo_data.pareto_k, - ...: pointwise=True) - Notes ----- For a single observation with posterior-predictive draws :math:`x_1, \ldots, x_S` @@ -210,22 +195,22 @@ def loo_score( _validate_crps_input(y_pred, y_obs, log_likelihood, sample_dims=sample_dims, obs_dims=obs_dims) - if (log_weights is None) != (pareto_k is None): - raise ValueError( - "Both log_weights and pareto_k must be provided together or both must be None. " - "Only one was provided." + if log_weights is not None and pareto_k is not None: + pointwise_scores, pareto_k_out = y_pred.azstats.loo_score( + y_obs=y_obs, + log_weights=log_weights, + pareto_k=pareto_k, + kind=kind, + r_eff=r_eff, + sample_dims=sample_dims, ) - - if log_weights is None and pareto_k is None: - log_weights_da, pareto_k = log_likelihood.azstats.psislw(r_eff=r_eff, dim=sample_dims) else: - log_weights_da = log_weights - - pointwise_scores = y_pred.azstats.loo_score( - y_obs=y_obs, log_weights=log_weights_da, kind=kind, sample_dims=sample_dims - ) + log_ratios = -log_likelihood + pointwise_scores, pareto_k_out = y_pred.azstats.loo_score( + y_obs=y_obs, log_ratios=log_ratios, kind=kind, r_eff=r_eff, sample_dims=sample_dims + ) - _warn_pareto_k(pareto_k, n_samples) + _warn_pareto_k(pareto_k_out, n_samples) n_pts = int(np.prod([pointwise_scores.sizes[d] for d in pointwise_scores.dims])) mean = pointwise_scores.mean().values.item() @@ -233,10 +218,11 @@ def loo_score( name = "SCRPS" if kind == "scrps" else "CRPS" if pointwise: - return namedtuple(name, ["mean", "se", "pointwise"])( + return namedtuple(name, ["mean", "se", "pointwise", "pareto_k"])( round_num(mean, round_to), round_num(se, round_to), pointwise_scores, + pareto_k_out, ) return namedtuple(name, ["mean", "se"])( round_num(mean, round_to), diff --git a/tests/base/test_array.py b/tests/base/test_array.py index 174bdae..ceccddf 100644 --- a/tests/base/test_array.py +++ b/tests/base/test_array.py @@ -1,9 +1,7 @@ """Tests for array interface base functions.""" # pylint: disable=redefined-outer-name, no-self-use, protected-access -import numpy as np import pytest -from numpy.testing import assert_allclose, assert_array_equal from arviz_stats.base.array import ( BaseArray, @@ -13,11 +11,14 @@ from ..helpers import importorskip +np = importorskip("numpy") azb = importorskip("arviz_base") einstats = importorskip("xarray_einstats") xr = importorskip("xarray") -from arviz_stats import loo, loo_approximate_posterior, loo_score +from numpy.testing import assert_allclose, assert_array_equal + +from arviz_stats import loo, loo_approximate_posterior, loo_pit, loo_score from arviz_stats.loo.helper_loo import _get_r_eff, _prepare_loo_inputs from arviz_stats.utils import get_log_likelihood_dataset @@ -684,8 +685,8 @@ def test_loo_diff_chain_draw_axes(self, array_stats, rng): def test_loo_with_reff(self, array_stats, rng): ary = rng.normal(-2, 1, size=(4, 100)) - elpd_i_1, pareto_k_1, _ = array_stats.loo(ary, reff=0.5) - elpd_i_2, pareto_k_2, _ = array_stats.loo(ary, reff=1.0) + elpd_i_1, pareto_k_1, _ = array_stats.loo(ary, r_eff=0.5) + elpd_i_2, pareto_k_2, _ = array_stats.loo(ary, r_eff=1.0) assert not np.isclose(pareto_k_1, pareto_k_2) or not np.isclose(elpd_i_1, elpd_i_2) @@ -703,7 +704,7 @@ def test_loo_matches_xarray(self, array_stats, centered_eight): loo_xr = loo(centered_eight, pointwise=True, var_name="obs") elpd_i_array, pareto_k_array, p_loo_i_array = array_stats.loo( - log_lik.values, chain_axis=0, draw_axis=1, reff=reff + log_lik.values, chain_axis=0, draw_axis=1, r_eff=reff ) lppd_xr = einstats.stats.logsumexp( @@ -811,6 +812,58 @@ def test_loo_approximate_posterior_matches_xarray(self, array_stats, centered_ei assert_allclose(p_loo_i_array, p_loo_i_xr.values, rtol=1e-10) +class TestLOOR2: + def test_loo_r2_basic(self, array_stats, rng): + y_obs = rng.normal(size=(100,)) + ypred_loo = y_obs + rng.normal(0, 0.1, size=(100,)) + + r2_samples = array_stats.loo_r2(y_obs, ypred_loo, n_simulations=100) + + assert r2_samples.shape == (100,) + assert np.all(np.isfinite(r2_samples)) + assert np.all(r2_samples >= -1) + assert np.all(r2_samples <= 1) + + def test_loo_r2_perfect_prediction(self, array_stats, rng): + y_obs = rng.normal(size=(50,)) + ypred_loo = y_obs.copy() + + r2_samples = array_stats.loo_r2(y_obs, ypred_loo, n_simulations=100) + + assert r2_samples.shape == (100,) + assert np.all(r2_samples > 0.99) + + def test_loo_r2_n_simulations(self, array_stats, rng): + y_obs = rng.normal(size=(50,)) + ypred_loo = y_obs + rng.normal(0, 0.5, size=(50,)) + + r2_100 = array_stats.loo_r2(y_obs, ypred_loo, n_simulations=100) + r2_500 = array_stats.loo_r2(y_obs, ypred_loo, n_simulations=500) + + assert r2_100.shape == (100,) + assert r2_500.shape == (500,) + + def test_loo_r2_random_state(self, array_stats, rng): + y_obs = rng.normal(size=(50,)) + ypred_loo = y_obs + rng.normal(0, 0.5, size=(50,)) + + r2_1 = array_stats.loo_r2(y_obs, ypred_loo, n_simulations=100, random_state=42) + r2_2 = array_stats.loo_r2(y_obs, ypred_loo, n_simulations=100, random_state=42) + r2_3 = array_stats.loo_r2(y_obs, ypred_loo, n_simulations=100, random_state=123) + + assert_allclose(r2_1, r2_2) + assert not np.allclose(r2_1, r2_3) + + def test_loo_r2_circular(self, array_stats, rng): + y_obs = rng.vonmises(0, 2, size=(50,)) + ypred_loo = y_obs + rng.vonmises(0, 10, size=(50,)) + + r2_samples = array_stats.loo_r2(y_obs, ypred_loo, n_simulations=100, circular=True) + + assert r2_samples.shape == (100,) + assert np.all(np.isfinite(r2_samples)) + + class TestLOOScore: @pytest.mark.parametrize("kind", ["crps", "scrps"]) def test_loo_score_basic(self, array_stats, centered_eight, kind): @@ -826,7 +879,12 @@ def test_loo_score_basic(self, array_stats, centered_eight, kind): log_weights = log_weights_xr.values[:, :, :1] scores = array_stats.loo_score( - y_pred, y_obs, log_weights, kind=kind, chain_axis=0, draw_axis=1 + y_pred, + y_obs, + log_weights, + kind=kind, + chain_axis=0, + draw_axis=1, ) assert scores.shape == (1,) @@ -845,7 +903,12 @@ def test_loo_score_multiple_obs(self, array_stats, centered_eight): log_weights = log_weights_xr.values scores = array_stats.loo_score( - y_pred, y_obs, log_weights, kind="crps", chain_axis=0, draw_axis=1 + y_pred, + y_obs, + log_weights, + kind="crps", + chain_axis=0, + draw_axis=1, ) assert scores.shape == (8,) @@ -867,7 +930,12 @@ def test_loo_score_chain_axis_none(self, array_stats, centered_eight): log_weights_flat = log_weights.reshape(-1, log_weights.shape[-1]) scores = array_stats.loo_score( - y_pred_flat, y_obs, log_weights_flat, kind="crps", chain_axis=None, draw_axis=0 + y_pred_flat, + y_obs, + log_weights_flat, + kind="crps", + chain_axis=None, + draw_axis=0, ) assert scores.shape == (8,) @@ -889,7 +957,12 @@ def test_loo_score_diff_axes(self, array_stats, centered_eight): log_weights_reorder = np.transpose(log_weights, (2, 0, 1)) scores = array_stats.loo_score( - y_pred_reorder, y_obs, log_weights_reorder, kind="crps", chain_axis=1, draw_axis=2 + y_pred_reorder, + y_obs, + log_weights_reorder, + kind="crps", + chain_axis=1, + draw_axis=2, ) assert scores.shape == (8,) @@ -911,7 +984,484 @@ def test_loo_score_matches_xarray(self, array_stats, centered_eight, kind): log_weights = log_weights_xr.values scores_array = array_stats.loo_score( - y_pred, y_obs, log_weights, kind=kind, chain_axis=0, draw_axis=1 + y_pred, + y_obs, + log_weights, + kind=kind, + chain_axis=0, + draw_axis=1, ) assert_allclose(scores_array, loo_score_xr.pointwise.values, rtol=1e-10) + + +class TestLooPit: + def test_loo_pit_basic(self, array_stats, centered_eight): + log_lik = get_log_likelihood_dataset(centered_eight, var_names="obs")["obs"] + n_samples = log_lik.chain.size * log_lik.draw.size + reff = _get_r_eff(centered_eight, n_samples) + + log_weights_ds, _ = log_lik.azstats.psislw(r_eff=reff) + log_weights_xr = log_weights_ds.transpose("chain", "draw", "school") + + y_pred = centered_eight.posterior_predictive["obs"].values[:, :, :1] + y_obs = centered_eight.observed_data["obs"].values[:1] + log_weights = log_weights_xr.values[:, :, :1] + + pit_values = array_stats.loo_pit( + y_pred, + y_obs, + log_weights, + chain_axis=0, + draw_axis=1, + ) + + assert pit_values.shape == (1,) + assert np.all(pit_values >= 0) + assert np.all(pit_values <= 1) + + def test_loo_pit_multiple_obs(self, array_stats, centered_eight): + log_lik = get_log_likelihood_dataset(centered_eight, var_names="obs")["obs"] + n_samples = log_lik.chain.size * log_lik.draw.size + reff = _get_r_eff(centered_eight, n_samples) + + log_weights_ds, _ = log_lik.azstats.psislw(r_eff=reff) + log_weights_xr = log_weights_ds.transpose("chain", "draw", "school") + + y_pred = centered_eight.posterior_predictive["obs"].values + y_obs = centered_eight.observed_data["obs"].values + log_weights = log_weights_xr.values + + pit_values = array_stats.loo_pit( + y_pred, + y_obs, + log_weights, + chain_axis=0, + draw_axis=1, + ) + + assert pit_values.shape == (8,) + assert np.all(pit_values >= 0) + assert np.all(pit_values <= 1) + + def test_loo_pit_chain_axis_none(self, array_stats, centered_eight): + log_lik = get_log_likelihood_dataset(centered_eight, var_names="obs")["obs"] + n_samples = log_lik.chain.size * log_lik.draw.size + reff = _get_r_eff(centered_eight, n_samples) + + log_weights_ds, _ = log_lik.azstats.psislw(r_eff=reff) + log_weights_xr = log_weights_ds.transpose("chain", "draw", "school") + + y_pred = centered_eight.posterior_predictive["obs"].values + y_obs = centered_eight.observed_data["obs"].values + log_weights = log_weights_xr.values + + y_pred_flat = y_pred.reshape(-1, y_pred.shape[-1]) + log_weights_flat = log_weights.reshape(-1, log_weights.shape[-1]) + + pit_values = array_stats.loo_pit( + y_pred_flat, + y_obs, + log_weights_flat, + chain_axis=None, + draw_axis=0, + ) + + assert pit_values.shape == (8,) + assert np.all(pit_values >= 0) + assert np.all(pit_values <= 1) + + def test_loo_pit_diff_axes(self, array_stats, centered_eight): + log_lik = get_log_likelihood_dataset(centered_eight, var_names="obs")["obs"] + n_samples = log_lik.chain.size * log_lik.draw.size + reff = _get_r_eff(centered_eight, n_samples) + + log_weights_ds, _ = log_lik.azstats.psislw(r_eff=reff) + log_weights_xr = log_weights_ds.transpose("chain", "draw", "school") + + y_pred = centered_eight.posterior_predictive["obs"].values + y_obs = centered_eight.observed_data["obs"].values + log_weights = log_weights_xr.values + + y_pred_reorder = np.transpose(y_pred, (2, 0, 1)) + log_weights_reorder = np.transpose(log_weights, (2, 0, 1)) + + pit_values = array_stats.loo_pit( + y_pred_reorder, + y_obs, + log_weights_reorder, + chain_axis=1, + draw_axis=2, + ) + + assert pit_values.shape == (8,) + assert np.all(pit_values >= 0) + assert np.all(pit_values <= 1) + + def test_loo_pit_matches_xarray(self, array_stats, centered_eight): + log_lik = get_log_likelihood_dataset(centered_eight, var_names="obs")["obs"] + n_samples = log_lik.chain.size * log_lik.draw.size + reff = _get_r_eff(centered_eight, n_samples) + + log_weights_ds, _ = log_lik.azstats.psislw(r_eff=reff) + log_weights_xr = log_weights_ds.transpose("chain", "draw", "school") + + loo_pit_xr = loo_pit(centered_eight) + + y_pred = centered_eight.posterior_predictive["obs"].values + y_obs = centered_eight.observed_data["obs"].values + log_weights = log_weights_xr.values + + pit_values_array = array_stats.loo_pit( + y_pred, + y_obs, + log_weights, + chain_axis=0, + draw_axis=1, + ) + + assert_allclose(pit_values_array, loo_pit_xr["obs"].values, rtol=1e-10) + + def test_loo_pit_discrete_randomization(self, array_stats, rng): + n_chains, n_draws, n_obs = 2, 100, 5 + y_pred = rng.integers(0, 10, size=(n_chains, n_draws, n_obs)).astype(float) + y_obs = np.array([3.0, 5.0, 7.0, 2.0, 8.0]) + + log_weights = rng.normal(size=(n_chains, n_draws, n_obs)) + + pit_values = array_stats.loo_pit( + y_pred, + y_obs, + log_weights, + chain_axis=0, + draw_axis=1, + ) + + assert pit_values.shape == (n_obs,) + assert np.all(pit_values >= 0) + assert np.all(pit_values <= 1) + + has_ties = False + for i in range(n_obs): + if np.any(y_pred[:, :, i] == y_obs[i]): + has_ties = True + break + + assert has_ties, "Test data should have ties for this test to be meaningful" + + def test_loo_pit_random_state_reproducibility(self, array_stats, rng): + n_chains, n_draws, n_obs = 2, 100, 5 + y_pred = rng.integers(0, 10, size=(n_chains, n_draws, n_obs)).astype(float) + y_obs = np.array([3.0, 5.0, 7.0, 2.0, 8.0]) + log_weights = rng.normal(size=(n_chains, n_draws, n_obs)) + + pit_1 = array_stats.loo_pit( + y_pred, y_obs, log_weights, chain_axis=0, draw_axis=1, random_state=42 + ) + pit_2 = array_stats.loo_pit( + y_pred, y_obs, log_weights, chain_axis=0, draw_axis=1, random_state=42 + ) + pit_3 = array_stats.loo_pit( + y_pred, y_obs, log_weights, chain_axis=0, draw_axis=1, random_state=123 + ) + + assert_allclose(pit_1, pit_2) + assert not np.allclose(pit_1, pit_3) + + def test_loo_pit_single_observation(self, array_stats, centered_eight): + log_lik = get_log_likelihood_dataset(centered_eight, var_names="obs")["obs"] + n_samples = log_lik.chain.size * log_lik.draw.size + reff = _get_r_eff(centered_eight, n_samples) + + log_weights_ds, _ = log_lik.azstats.psislw(r_eff=reff) + log_weights_xr = log_weights_ds.transpose("chain", "draw", "school") + + y_pred = centered_eight.posterior_predictive["obs"].values[:, :, :1] + y_obs = centered_eight.observed_data["obs"].values[:1] + log_weights = log_weights_xr.values[:, :, :1] + + pit_values = array_stats.loo_pit(y_pred, y_obs, log_weights, chain_axis=0, draw_axis=1) + + assert pit_values.shape == (1,) + assert 0 <= pit_values[0] <= 1 + + +class TestLooExpectation: + def test_loo_expectation(self, array_stats, centered_eight): + log_lik = get_log_likelihood_dataset(centered_eight, var_names="obs")["obs"] + n_samples = log_lik.chain.size * log_lik.draw.size + reff = _get_r_eff(centered_eight, n_samples) + + log_weights_ds, _ = log_lik.azstats.psislw(r_eff=reff) + log_weights_xr = log_weights_ds.transpose("chain", "draw", "school") + + y_pred = centered_eight.posterior_predictive["obs"].values[:, :, :1] + log_weights = log_weights_xr.values[:, :, :1] + + expectation = array_stats.loo_expectation( + y_pred, + log_weights, + kind="mean", + chain_axis=0, + draw_axis=1, + ) + + assert expectation.shape == (1,) + assert np.all(np.isfinite(expectation)) + + def test_loo_expectation_multiple_obs(self, array_stats, centered_eight): + log_lik = get_log_likelihood_dataset(centered_eight, var_names="obs")["obs"] + n_samples = log_lik.chain.size * log_lik.draw.size + reff = _get_r_eff(centered_eight, n_samples) + + log_weights_ds, _ = log_lik.azstats.psislw(r_eff=reff) + log_weights_xr = log_weights_ds.transpose("chain", "draw", "school") + + y_pred = centered_eight.posterior_predictive["obs"].values + log_weights = log_weights_xr.values + + expectation = array_stats.loo_expectation( + y_pred, + log_weights, + kind="mean", + chain_axis=0, + draw_axis=1, + ) + + assert expectation.shape == (8,) + assert np.all(np.isfinite(expectation)) + + @pytest.mark.parametrize("kind", ["mean", "median", "var", "sd"]) + def test_loo_expectation_kinds(self, array_stats, centered_eight, kind): + log_lik = get_log_likelihood_dataset(centered_eight, var_names="obs")["obs"] + n_samples = log_lik.chain.size * log_lik.draw.size + reff = _get_r_eff(centered_eight, n_samples) + + log_weights_ds, _ = log_lik.azstats.psislw(r_eff=reff) + log_weights_xr = log_weights_ds.transpose("chain", "draw", "school") + + y_pred = centered_eight.posterior_predictive["obs"].values + log_weights = log_weights_xr.values + + expectation = array_stats.loo_expectation( + y_pred, + log_weights, + kind=kind, + chain_axis=0, + draw_axis=1, + ) + + assert expectation.shape == (8,) + assert np.all(np.isfinite(expectation)) + if kind in ("var", "sd"): + assert np.all(expectation >= 0) + + @pytest.mark.parametrize("kind", ["circular_mean", "circular_var", "circular_sd"]) + def test_loo_expectation_circular_kinds(self, array_stats, rng, kind): + n_chains, n_draws, n_obs = 2, 100, 5 + angles = rng.uniform(-np.pi, np.pi, size=(n_chains, n_draws, n_obs)) + log_weights = rng.normal(size=(n_chains, n_draws, n_obs)) + + expectation = array_stats.loo_expectation( + angles, + log_weights, + kind=kind, + chain_axis=0, + draw_axis=1, + ) + + assert expectation.shape == (n_obs,) + assert np.all(np.isfinite(expectation)) + if kind == "circular_mean": + assert np.all(expectation >= -np.pi) + assert np.all(expectation <= np.pi) + elif kind == "circular_var": + assert np.all(expectation >= 0) + assert np.all(expectation <= 1) + + def test_loo_expectation_chain_axis_none(self, array_stats, centered_eight): + log_lik = get_log_likelihood_dataset(centered_eight, var_names="obs")["obs"] + n_samples = log_lik.chain.size * log_lik.draw.size + reff = _get_r_eff(centered_eight, n_samples) + + log_weights_ds, _ = log_lik.azstats.psislw(r_eff=reff) + log_weights_xr = log_weights_ds.transpose("chain", "draw", "school") + + y_pred = centered_eight.posterior_predictive["obs"].values + log_weights = log_weights_xr.values + + y_pred_flat = y_pred.reshape(-1, y_pred.shape[-1]) + log_weights_flat = log_weights.reshape(-1, log_weights.shape[-1]) + + expectation = array_stats.loo_expectation( + y_pred_flat, + log_weights_flat, + kind="mean", + chain_axis=None, + draw_axis=0, + ) + + assert expectation.shape == (8,) + assert np.all(np.isfinite(expectation)) + + def test_loo_expectation_diff_axes(self, array_stats, centered_eight): + log_lik = get_log_likelihood_dataset(centered_eight, var_names="obs")["obs"] + n_samples = log_lik.chain.size * log_lik.draw.size + reff = _get_r_eff(centered_eight, n_samples) + + log_weights_ds, _ = log_lik.azstats.psislw(r_eff=reff) + log_weights_xr = log_weights_ds.transpose("chain", "draw", "school") + + y_pred = centered_eight.posterior_predictive["obs"].values + log_weights = log_weights_xr.values + + y_pred_reorder = np.transpose(y_pred, (2, 0, 1)) + log_weights_reorder = np.transpose(log_weights, (2, 0, 1)) + + expectation = array_stats.loo_expectation( + y_pred_reorder, + log_weights_reorder, + kind="mean", + chain_axis=1, + draw_axis=2, + ) + + assert expectation.shape == (8,) + assert np.all(np.isfinite(expectation)) + + +class TestLooQuantile: + def test_loo_quantile_basic(self, array_stats, centered_eight): + log_lik = get_log_likelihood_dataset(centered_eight, var_names="obs")["obs"] + n_samples = log_lik.chain.size * log_lik.draw.size + reff = _get_r_eff(centered_eight, n_samples) + + log_weights_ds, _ = log_lik.azstats.psislw(r_eff=reff) + log_weights_xr = log_weights_ds.transpose("chain", "draw", "school") + + y_pred = centered_eight.posterior_predictive["obs"].values[:, :, :1] + log_weights = log_weights_xr.values[:, :, :1] + + quantile = array_stats.loo_quantile( + y_pred, + log_weights, + prob=0.5, + chain_axis=0, + draw_axis=1, + ) + + assert quantile.shape == (1,) + assert np.all(np.isfinite(quantile)) + + def test_loo_quantile_multiple_obs(self, array_stats, centered_eight): + log_lik = get_log_likelihood_dataset(centered_eight, var_names="obs")["obs"] + n_samples = log_lik.chain.size * log_lik.draw.size + reff = _get_r_eff(centered_eight, n_samples) + + log_weights_ds, _ = log_lik.azstats.psislw(r_eff=reff) + log_weights_xr = log_weights_ds.transpose("chain", "draw", "school") + + y_pred = centered_eight.posterior_predictive["obs"].values + log_weights = log_weights_xr.values + + quantile = array_stats.loo_quantile( + y_pred, + log_weights, + prob=0.5, + chain_axis=0, + draw_axis=1, + ) + + assert quantile.shape == (8,) + assert np.all(np.isfinite(quantile)) + + @pytest.mark.parametrize("prob", [0.1, 0.25, 0.5, 0.75, 0.9]) + def test_loo_quantile_probs(self, array_stats, centered_eight, prob): + log_lik = get_log_likelihood_dataset(centered_eight, var_names="obs")["obs"] + n_samples = log_lik.chain.size * log_lik.draw.size + reff = _get_r_eff(centered_eight, n_samples) + + log_weights_ds, _ = log_lik.azstats.psislw(r_eff=reff) + log_weights_xr = log_weights_ds.transpose("chain", "draw", "school") + + y_pred = centered_eight.posterior_predictive["obs"].values + log_weights = log_weights_xr.values + + quantile = array_stats.loo_quantile( + y_pred, + log_weights, + prob=prob, + chain_axis=0, + draw_axis=1, + ) + + assert quantile.shape == (8,) + assert np.all(np.isfinite(quantile)) + + def test_loo_quantile_ordering(self, array_stats, centered_eight): + log_lik = get_log_likelihood_dataset(centered_eight, var_names="obs")["obs"] + n_samples = log_lik.chain.size * log_lik.draw.size + reff = _get_r_eff(centered_eight, n_samples) + + log_weights_ds, _ = log_lik.azstats.psislw(r_eff=reff) + log_weights_xr = log_weights_ds.transpose("chain", "draw", "school") + + y_pred = centered_eight.posterior_predictive["obs"].values + log_weights = log_weights_xr.values + + q25 = array_stats.loo_quantile(y_pred, log_weights, prob=0.25, chain_axis=0, draw_axis=1) + q50 = array_stats.loo_quantile(y_pred, log_weights, prob=0.50, chain_axis=0, draw_axis=1) + q75 = array_stats.loo_quantile(y_pred, log_weights, prob=0.75, chain_axis=0, draw_axis=1) + + assert np.all(q25 <= q50) + assert np.all(q50 <= q75) + + def test_loo_quantile_chain_axis_none(self, array_stats, centered_eight): + log_lik = get_log_likelihood_dataset(centered_eight, var_names="obs")["obs"] + n_samples = log_lik.chain.size * log_lik.draw.size + reff = _get_r_eff(centered_eight, n_samples) + + log_weights_ds, _ = log_lik.azstats.psislw(r_eff=reff) + log_weights_xr = log_weights_ds.transpose("chain", "draw", "school") + + y_pred = centered_eight.posterior_predictive["obs"].values + log_weights = log_weights_xr.values + + y_pred_flat = y_pred.reshape(-1, y_pred.shape[-1]) + log_weights_flat = log_weights.reshape(-1, log_weights.shape[-1]) + + quantile = array_stats.loo_quantile( + y_pred_flat, + log_weights_flat, + prob=0.5, + chain_axis=None, + draw_axis=0, + ) + + assert quantile.shape == (8,) + assert np.all(np.isfinite(quantile)) + + def test_loo_quantile_diff_axes(self, array_stats, centered_eight): + log_lik = get_log_likelihood_dataset(centered_eight, var_names="obs")["obs"] + n_samples = log_lik.chain.size * log_lik.draw.size + reff = _get_r_eff(centered_eight, n_samples) + + log_weights_ds, _ = log_lik.azstats.psislw(r_eff=reff) + log_weights_xr = log_weights_ds.transpose("chain", "draw", "school") + + y_pred = centered_eight.posterior_predictive["obs"].values + log_weights = log_weights_xr.values + + y_pred_reorder = np.transpose(y_pred, (2, 0, 1)) + log_weights_reorder = np.transpose(log_weights, (2, 0, 1)) + + quantile = array_stats.loo_quantile( + y_pred_reorder, + log_weights_reorder, + prob=0.5, + chain_axis=1, + draw_axis=2, + ) + + assert quantile.shape == (8,) + assert np.all(np.isfinite(quantile)) diff --git a/tests/loo/test_loo_expectations.py b/tests/loo/test_loo_expectations.py index a354216..a8d3e53 100644 --- a/tests/loo/test_loo_expectations.py +++ b/tests/loo/test_loo_expectations.py @@ -1,16 +1,18 @@ """Test expectations functions for PSIS-LOO-CV.""" # pylint: disable=redefined-outer-name, unused-argument -import numpy as np import pytest -from numpy.testing import assert_allclose, assert_almost_equal, assert_array_equal from ..helpers import importorskip +np = importorskip("numpy") azb = importorskip("arviz_base") -from arviz_stats import loo, loo_expectations, loo_metrics, loo_r2 -from arviz_stats.utils import ELPDData +from numpy.testing import assert_allclose, assert_almost_equal, assert_array_equal + +from arviz_stats import loo_expectations, loo_metrics, loo_r2 +from arviz_stats.loo.helper_loo import _get_r_eff +from arviz_stats.utils import get_log_likelihood_dataset def test_loo_expectations_invalid_kind(centered_eight): @@ -28,31 +30,11 @@ def test_loo_expectations_invalid_var_name(centered_eight): loo_expectations(centered_eight, var_name="nonexistent") -def test_loo_expectations_elpddata_without_log_weights(centered_eight): - np.random.default_rng(42) - - loo_result_no_weights = ELPDData( - elpd=-30.0, - se=3.0, - p=2.0, - good_k=0.7, - n_samples=100, - n_data_points=8, - warning=False, - kind="loo", - scale="log", - log_weights=None, - ) - - with pytest.raises(ValueError, match="ELPDData object does not contain log_weights"): - loo_expectations(centered_eight, log_weights=loo_result_no_weights) - - @pytest.mark.parametrize( "kind, probs, expected_vals", [ ("mean", None, 3.81), - ("quantile", [0.25, 0.75], [-6.26, 14.44]), + ("quantile", [0.25, 0.75], [-6.27, 14.44]), ], ) def test_loo_expectations(centered_eight, kind, probs, expected_vals): @@ -104,18 +86,6 @@ def test_loo_expectations_khat(centered_eight, datatree, kind): loo_expectations(datatree, var_name="y", kind=kind, probs=probs) -def test_log_weights_input_formats(centered_eight): - loo_result = loo(centered_eight, pointwise=True) - log_weights_da = loo_result.log_weights - - loo_exp_da, khat_da = loo_expectations(centered_eight, kind="mean", log_weights=log_weights_da) - loo_exp_elpddata, khat_elpddata = loo_expectations( - centered_eight, kind="mean", log_weights=loo_result - ) - assert_array_equal(loo_exp_da.values, loo_exp_elpddata.values) - assert_array_equal(khat_da.values, khat_elpddata.values) - - @pytest.mark.parametrize("kind", ["median", "sd"]) def test_loo_expectations_median_sd(centered_eight, kind): result, khat = loo_expectations(centered_eight, kind=kind) @@ -125,6 +95,9 @@ def test_loo_expectations_median_sd(centered_eight, kind): assert np.all(np.isfinite(result.values)) assert np.all(np.isfinite(khat.values)) + if kind == "sd": + assert np.all(result.values >= 0) + def test_loo_expectations_single_quantile(centered_eight): result, khat = loo_expectations(centered_eight, kind="quantile", probs=0.5) @@ -183,7 +156,7 @@ def test_loo_expectations_with_explicit_var_name(centered_eight): @pytest.mark.parametrize("kind", ["mae", "mse", "rmse"]) def test_loo_metrics(centered_eight, kind): - result = loo_metrics(centered_eight, kind=kind) + result = loo_metrics(centered_eight, kind=kind, round_to=2) assert hasattr(result, "_fields") assert hasattr(result, "mean") @@ -192,18 +165,8 @@ def test_loo_metrics(centered_eight, kind): assert isinstance(result.se, int | float | str) -def test_loo_metrics_with_log_weights(centered_eight): - loo_result = loo(centered_eight, pointwise=True) - - result_with_weights = loo_metrics(centered_eight, kind="rmse", log_weights=loo_result) - result_without_weights = loo_metrics(centered_eight, kind="rmse") - - assert hasattr(result_with_weights, "mean") - assert hasattr(result_without_weights, "mean") - - def test_loo_metrics_explicit_var_name(centered_eight): - result = loo_metrics(centered_eight, var_name="obs", kind="mae") + result = loo_metrics(centered_eight, var_name="obs", kind="mae", round_to=2) assert hasattr(result, "mean") assert hasattr(result, "se") @@ -254,11 +217,67 @@ def test_loo_r2_ci_prob(datatree_regression, ci_prob): @pytest.mark.parametrize("kind", ["circular_mean", "circular_var", "circular_sd"]) def test_loo_expectations_circular(centered_eight, kind): - """Simple parametric checks for circular kinds: shape and finiteness of result and khat.""" - result, khat = loo_expectations(centered_eight, kind=kind) assert result.shape == (8,) assert khat.shape == (8,) assert np.all(np.isfinite(result.values)) assert np.all(np.isfinite(khat.values)) + + if kind == "circular_mean": + assert np.all(result.values >= -np.pi) + assert np.all(result.values <= np.pi) + elif kind == "circular_var": + assert np.all(result.values >= 0) + assert np.all(result.values <= 1) + else: # circular_sd + assert np.all(result.values >= 0) + + +@pytest.mark.parametrize("kind", ["mean", "var", "sd", "median"]) +def test_loo_expectations_precomputed_weights(centered_eight, kind): + result_auto, _ = loo_expectations(centered_eight, kind=kind) + + var_name = "obs" + log_likelihood = get_log_likelihood_dataset(centered_eight, var_names=var_name) + n_samples = log_likelihood[var_name].sizes["chain"] * log_likelihood[var_name].sizes["draw"] + r_eff = _get_r_eff(centered_eight, n_samples) + + log_weights_computed, pareto_k_computed = log_likelihood[var_name].azstats.psislw( + dim=["chain", "draw"], + r_eff=r_eff, + ) + + result_precomputed, _ = loo_expectations( + centered_eight, + kind=kind, + log_weights=log_weights_computed, + pareto_k=pareto_k_computed, + ) + + assert_allclose(result_precomputed.values, result_auto.values, rtol=1e-10) + + +def test_loo_expectations_quantile_precomputed_weights(centered_eight): + probs = [0.25, 0.75] + result_auto, _ = loo_expectations(centered_eight, kind="quantile", probs=probs) + + var_name = "obs" + log_likelihood = get_log_likelihood_dataset(centered_eight, var_names=var_name) + n_samples = log_likelihood[var_name].sizes["chain"] * log_likelihood[var_name].sizes["draw"] + r_eff = _get_r_eff(centered_eight, n_samples) + + log_weights_computed, pareto_k_computed = log_likelihood[var_name].azstats.psislw( + dim=["chain", "draw"], + r_eff=r_eff, + ) + + result_precomputed, _ = loo_expectations( + centered_eight, + kind="quantile", + probs=probs, + log_weights=log_weights_computed, + pareto_k=pareto_k_computed, + ) + + assert_allclose(result_precomputed.values, result_auto.values, rtol=1e-10) diff --git a/tests/loo/test_loo_pit.py b/tests/loo/test_loo_pit.py index 3c1feca..a258007 100644 --- a/tests/loo/test_loo_pit.py +++ b/tests/loo/test_loo_pit.py @@ -1,16 +1,19 @@ """Test probability integral transform for PSIS-LOO-CV.""" # pylint: disable=redefined-outer-name, unused-argument -import numpy as np import pytest -from numpy.testing import assert_array_equal from ..helpers import importorskip +np = importorskip("numpy") +xr = importorskip("xarray") azb = importorskip("arviz_base") -from arviz_stats import loo, loo_pit -from arviz_stats.utils import ELPDData, get_log_likelihood_dataset +from numpy.testing import assert_almost_equal + +from arviz_stats import loo_pit +from arviz_stats.loo.helper_loo import _get_r_eff +from arviz_stats.utils import get_log_likelihood_dataset def test_loo_pit_invalid_var_name(centered_eight): @@ -18,45 +21,20 @@ def test_loo_pit_invalid_var_name(centered_eight): loo_pit(centered_eight, var_names="nonexistent") -def test_loo_pit_elpddata_without_log_weights(centered_eight): - loo_result_no_weights = ELPDData( - elpd=-30.0, - se=3.0, - p=2.0, - good_k=0.7, - n_samples=100, - n_data_points=8, - warning=False, - kind="loo", - scale="log", - log_weights=None, - ) - - with pytest.raises(ValueError, match="ELPDData object does not contain log_weights"): - loo_pit(centered_eight, log_weights=loo_result_no_weights) - - @pytest.mark.parametrize( "args", [ {}, {"var_names": ["obs"]}, - {"log_weights": "arr"}, - {"var_names": ["obs"]}, {"var_names": "obs"}, ], ) def test_loo_pit(centered_eight, args): var_names = args.get("var_names", None) - log_weights = args.get("log_weights", None) - - if log_weights == "arr": - log_weights = get_log_likelihood_dataset(centered_eight, var_names=var_names) loo_pit_values = loo_pit( centered_eight, var_names=var_names, - log_weights=log_weights, ) assert np.all(loo_pit_values >= 0) assert np.all(loo_pit_values <= 1) @@ -70,15 +48,6 @@ def test_loo_pit_discrete(centered_eight): assert np.all(loo_pit_values <= 1) -def test_log_weights_input_formats(centered_eight): - loo_result = loo(centered_eight, pointwise=True) - log_weights_da = loo_result.log_weights - - loo_pit_da = loo_pit(centered_eight, log_weights=log_weights_da) - loo_pit_elpddata = loo_pit(centered_eight, log_weights=loo_result) - assert_array_equal(loo_pit_da["obs"].values, loo_pit_elpddata["obs"].values) - - def test_loo_pit_all_var_names(centered_eight): result = loo_pit(centered_eight) @@ -148,12 +117,61 @@ def test_loo_pit_multidimensional(): assert np.all(result["y"].values <= 1) -def test_loo_pit_with_precomputed_log_weights(centered_eight): - loo_result = loo(centered_eight, pointwise=True) +def test_loo_pit_precomputed_weights(centered_eight): + result_auto = loo_pit(centered_eight) + + var_names = ["obs"] + log_likelihood = get_log_likelihood_dataset(centered_eight, var_names=var_names) + n_samples = log_likelihood.chain.size * log_likelihood.draw.size + r_eff = _get_r_eff(centered_eight, n_samples) + + log_weights_computed, pareto_k_computed = log_likelihood["obs"].azstats.psislw( + dim=["chain", "draw"], + r_eff=r_eff, + ) + + log_weights_ds = xr.Dataset({"obs": log_weights_computed}) + pareto_k_ds = xr.Dataset({"obs": pareto_k_computed}) + + result_precomputed = loo_pit( + centered_eight, + var_names="obs", + log_weights=log_weights_ds, + pareto_k=pareto_k_ds, + ) + + assert_almost_equal(result_precomputed["obs"].values, result_auto["obs"].values, decimal=10) + + +def test_loo_pit_random_state_reproducibility(rng): + discrete_data = azb.from_dict( + { + "posterior": {"mu": rng.normal(size=(2, 50))}, + "posterior_predictive": {"y": rng.integers(0, 10, size=(2, 50, 8)).astype(float)}, + "log_likelihood": {"y": rng.normal(size=(2, 50, 8))}, + "observed_data": {"y": np.array([3.0, 5.0, 7.0, 2.0, 8.0, 4.0, 6.0, 1.0])}, + } + ) + + result_1 = loo_pit(discrete_data, random_state=42) + result_2 = loo_pit(discrete_data, random_state=42) + result_3 = loo_pit(discrete_data, random_state=123) + + assert_almost_equal(result_1["y"].values, result_2["y"].values, decimal=10) + assert not np.allclose(result_1["y"].values, result_3["y"].values) + + +def test_loo_pit_random_state_with_discrete(rng): + discrete_data = azb.from_dict( + { + "posterior": {"mu": rng.normal(size=(2, 50))}, + "posterior_predictive": {"y": rng.integers(0, 10, size=(2, 50, 8)).astype(float)}, + "log_likelihood": {"y": rng.normal(size=(2, 50, 8))}, + "observed_data": {"y": np.array([3.0, 5.0, 7.0, 2.0, 8.0, 4.0, 6.0, 1.0])}, + } + ) - result_with_weights = loo_pit(centered_eight, log_weights=loo_result) - result_without_weights = loo_pit(centered_eight) + result_1 = loo_pit(discrete_data, random_state=42) + result_2 = loo_pit(discrete_data, random_state=42) - assert result_with_weights["obs"].shape == result_without_weights["obs"].shape - assert np.all(result_with_weights["obs"].values >= 0) - assert np.all(result_with_weights["obs"].values <= 1) + assert_almost_equal(result_1["y"].values, result_2["y"].values, decimal=10) diff --git a/tests/loo/test_loo_score.py b/tests/loo/test_loo_score.py index 6dc3ef1..3df7f4b 100644 --- a/tests/loo/test_loo_score.py +++ b/tests/loo/test_loo_score.py @@ -1,15 +1,17 @@ """Test score functions for PSIS-LOO-CV.""" # pylint: disable=redefined-outer-name, unused-argument -import numpy as np import pytest -from numpy.testing import assert_almost_equal from ..helpers import importorskip +np = importorskip("numpy") azb = importorskip("arviz_base") -from arviz_stats import loo, loo_score +from numpy.testing import assert_almost_equal + +from arviz_stats import loo_score +from arviz_stats.loo.helper_loo import _get_r_eff, _prepare_loo_inputs def test_loo_score_invalid_kind(centered_eight): @@ -22,16 +24,6 @@ def test_loo_score_invalid_var_name(centered_eight): loo_score(centered_eight, var_name="nonexistent") -def test_loo_score_mismatched_log_weights_pareto_k(centered_eight): - loo_result = loo(centered_eight, pointwise=True) - - with pytest.raises(ValueError, match="Both log_weights and pareto_k must be provided together"): - loo_score(centered_eight, log_weights=loo_result.log_weights, pareto_k=None) - - with pytest.raises(ValueError, match="Both log_weights and pareto_k must be provided together"): - loo_score(centered_eight, log_weights=None, pareto_k=loo_result.pareto_k) - - @pytest.mark.parametrize("kind", ["crps", "scrps"]) def test_loo_score_basic(centered_eight, kind): result = loo_score(centered_eight, kind=kind) @@ -40,22 +32,10 @@ def test_loo_score_basic(centered_eight, kind): assert hasattr(result, "se") assert np.isfinite(result.mean) assert np.isfinite(result.se) + assert result.se >= 0 - -@pytest.mark.parametrize("kind", ["crps", "scrps"]) -def test_loo_score_log_weights(centered_eight, kind): - loo_result = loo(centered_eight, pointwise=True) - - r1 = loo_score(centered_eight, kind=kind) - r2 = loo_score( - centered_eight, - log_weights=loo_result.log_weights, - pareto_k=loo_result.pareto_k, - kind=kind, - ) - - assert_almost_equal(r1.mean, r2.mean, decimal=10) - assert_almost_equal(r1.se, r2.se, decimal=10) + if kind == "crps": + assert result.mean <= 0 @pytest.mark.parametrize("kind", ["crps", "scrps"]) @@ -88,10 +68,16 @@ def test_loo_score_pointwise(centered_eight, kind): assert hasattr(result, "mean") assert hasattr(result, "se") assert hasattr(result, "pointwise") + assert hasattr(result, "pareto_k") assert np.isfinite(result.mean) assert np.isfinite(result.se) assert result.pointwise.shape == (8,) assert np.all(np.isfinite(result.pointwise.values)) + assert result.pareto_k.shape == (8,) + assert np.all(np.isfinite(result.pareto_k.values)) + + if kind == "crps": + assert np.all(result.pointwise.values <= 0) def test_loo_score_namedtuple_names(centered_eight): @@ -170,3 +156,31 @@ def test_loo_score_pointwise_shape_multidim(): result = loo_score(multi_var_data, kind="crps", pointwise=True) assert result.pointwise.shape == (5,) + + +@pytest.mark.parametrize("kind", ["crps", "scrps"]) +def test_loo_score_precomputed_weights(centered_eight, kind): + result_auto = loo_score(centered_eight, kind=kind, pointwise=True) + + loo_inputs = _prepare_loo_inputs(centered_eight, None) + log_likelihood = loo_inputs.log_likelihood + + log_weights_computed, pareto_k_computed = log_likelihood.azstats.psislw( + dim=["chain", "draw"], + r_eff=_get_r_eff(centered_eight, loo_inputs.n_samples), + ) + + result_precomputed = loo_score( + centered_eight, + kind=kind, + pointwise=True, + log_weights=log_weights_computed, + pareto_k=pareto_k_computed, + ) + + assert_almost_equal(result_precomputed.mean, result_auto.mean, decimal=10) + assert_almost_equal(result_precomputed.se, result_auto.se, decimal=10) + assert_almost_equal( + result_precomputed.pointwise.values, result_auto.pointwise.values, decimal=10 + ) + assert_almost_equal(result_precomputed.pareto_k.values, result_auto.pareto_k.values, decimal=10)