# Provenance: these tests were carried by a git format-patch.
#   From 1e7924eb36547276d58c9dcfda13a2b026f8014e Mon Sep 17 00:00:00 2001
#   From: Jordan Deklerk <111652310+jordandeklerk@users.noreply.github.com>
#   Date: Fri, 31 Oct 2025 14:03:39 -0400
#   Subject: [PATCH] tests: add tests for numpy logsumexp
#   tests/base/test_stats_utils.py | 207 +++ (1 file changed, 207 insertions(+))
#   index 7be056a..c8e9b6c 100644, hunk @@ -281,3 +281,210 @@
#
# Unchanged hunk context, belonging to test_logsumexp_edge_b (its start is not
# visible in the patch, so it is recorded here rather than reconstructed):
#   ary = np.array([1.0, 2.0, 3.0])
#   assert _logsumexp(ary, b=0) == -np.inf
#   assert _logsumexp(ary, b_inv=0) == np.inf


@pytest.mark.parametrize(
    "shape,axis,expected_shape",
    [
        ((4, 500, 20), (0, 1), (20,)),
        ((4, 500, 1), (0, 1), (1,)),
        ((4, 500, 8, 10), (0, 1), (8, 10)),
        ((4, 500, 20), (-3, -2), (20,)),
    ],
)
def test_logsumexp_loo_shapes(rng, shape, axis, expected_shape):
    """Check output shape and coarse value bounds against scipy for several input shapes."""
    log_lik = rng.normal(loc=-5, scale=2, size=shape)
    n_chains, n_draws = shape[:2]
    reduce_kwargs = {"axis": axis, "b": 1 / (n_chains * n_draws)}

    scipy_result = logsumexp(log_lik, **reduce_kwargs)
    arviz_result = _logsumexp(log_lik, **reduce_kwargs)

    for result in (scipy_result, arviz_result):
        assert result.shape == expected_shape
    assert_array_almost_equal(scipy_result, arviz_result, decimal=10)

    # The b-weighted result is checked against a loose window: [-10, max + 1].
    upper = np.max(log_lik, axis=axis) + 1
    for result in (scipy_result, arviz_result):
        assert np.all(result <= upper)
        assert np.all(result >= -10)


def test_logsumexp_loo_weights(rng):
    """Compare both implementations on log-likelihoods shifted by normalised log weights."""
    n_obs = 20
    shape = (4, 500, n_obs)
    sample_axes = (0, 1)

    log_lik = rng.normal(loc=-5, scale=2, size=shape)

    # Normalise so that, per observation, the weights sum to one on the linear scale.
    raw_log_weights = rng.normal(loc=-5, scale=1, size=shape)
    log_weights = raw_log_weights - logsumexp(raw_log_weights, axis=sample_axes, keepdims=True)

    log_weighted = log_weights + log_lik
    scipy_elpd_i = logsumexp(log_weighted, axis=sample_axes)
    arviz_elpd_i = _logsumexp(log_weighted, axis=sample_axes)

    assert scipy_elpd_i.shape == (n_obs,)
    assert_array_almost_equal(scipy_elpd_i, arviz_elpd_i, decimal=10)

    for elpd_i in (scipy_elpd_i, arviz_elpd_i):
        assert np.all(np.isfinite(elpd_i))

    # Sanity check of the normalisation itself: log(sum(weights)) == 0.
    assert_array_almost_equal(
        logsumexp(log_weights, axis=sample_axes), np.zeros(n_obs), decimal=10
    )


@pytest.mark.parametrize(
    "loc,scale",
    [
        (-0.9, 0.1),
        (-3.5, 1.5),
        (-20, 5),
        (-500, 50),
        (-0.1, 0.01),
    ],
)
def test_logsumexp_loo_values(rng, loc, scale):
    """Check finiteness, agreement with scipy and closeness to ``loc`` across magnitudes."""
    log_lik = rng.normal(loc=loc, scale=scale, size=(4, 500, 10))
    n_chains, n_draws = log_lik.shape[:2]
    reduce_kwargs = {"axis": (0, 1), "b": 1 / (n_chains * n_draws)}

    scipy_result = logsumexp(log_lik, **reduce_kwargs)
    arviz_result = _logsumexp(log_lik, **reduce_kwargs)

    for result in (scipy_result, arviz_result):
        assert np.all(np.isfinite(result))
    assert_array_almost_equal(scipy_result, arviz_result, decimal=8)

    tolerance = scale * 5
    for result in (scipy_result, arviz_result):
        assert np.all(np.abs(result - loc) < tolerance)


def test_logsumexp_loo_b_inv(rng):
    """``b_inv=n`` in ``_logsumexp`` must match scipy's ``b=1/n``."""
    n_chains, n_draws, n_obs = 4, 500, 20
    n_samples = n_chains * n_draws
    log_lik = rng.normal(loc=-5, scale=2, size=(n_chains, n_draws, n_obs))

    expected = logsumexp(log_lik, axis=(0, 1), b=1 / n_samples)
    actual = _logsumexp(log_lik, axis=(0, 1), b_inv=n_samples)

    assert_array_almost_equal(expected, actual, decimal=10)


@pytest.mark.parametrize(
    "n_chains,n_draws,n_obs",
    [
        (1, 1000, 10),
        (2, 500, 10),
        (8, 250, 10),
        (4, 2000, 5),
    ],
)
def test_logsumexp_loo_varying_dims(rng, n_chains, n_draws, n_obs):
    """Agreement with scipy must hold for different chain/draw/observation counts."""
    log_lik = rng.normal(loc=-5, scale=2, size=(n_chains, n_draws, n_obs))
    reduce_kwargs = {"axis": (0, 1), "b": 1 / (n_chains * n_draws)}

    expected = logsumexp(log_lik, **reduce_kwargs)
    actual = _logsumexp(log_lik, **reduce_kwargs)

    assert_array_almost_equal(expected, actual, decimal=10)


@pytest.mark.parametrize("keepdims", [True, False])
def test_logsumexp_loo_keepdims(rng, keepdims):
    """``keepdims`` must produce the same shape and values as scipy."""
    n_chains, n_draws, n_obs = 4, 500, 20
    log_lik = rng.normal(loc=-5, scale=2, size=(n_chains, n_draws, n_obs))
    reduce_kwargs = {"axis": (0, 1), "b": 1 / (n_chains * n_draws), "keepdims": keepdims}

    scipy_result = logsumexp(log_lik, **reduce_kwargs)
    arviz_result = _logsumexp(log_lik, **reduce_kwargs)

    if keepdims:
        expected_shape = (1, 1, n_obs)
    else:
        expected_shape = (n_obs,)

    for result in (scipy_result, arviz_result):
        assert result.shape == expected_shape
    assert_array_almost_equal(scipy_result, arviz_result, decimal=10)


def test_logsumexp_loo_copy(rng):
    """``copy=True`` and ``copy=False`` agree, and ``copy=True`` leaves its input intact."""
    n_chains, n_draws, n_obs = 4, 500, 20
    log_lik = rng.normal(loc=-5, scale=2, size=(n_chains, n_draws, n_obs))
    n_samples = n_chains * n_draws

    def reduce(ary, copy):
        return _logsumexp(ary, axis=(0, 1), b=1 / n_samples, copy=copy)

    # Every call gets its own copy so that one call cannot affect the next.
    with_copy = reduce(log_lik.copy(), copy=True)
    scratch = log_lik.copy()
    without_copy = reduce(scratch, copy=False)
    assert_array_almost_equal(with_copy, without_copy, decimal=10)

    probe = log_lik.copy()
    reduce(probe, copy=True)
    assert_array_almost_equal(log_lik, probe)


def test_logsumexp_loo_out(rng):
    """The ``out`` buffer is returned as-is and holds the scipy-equivalent result."""
    n_chains, n_draws, n_obs = 4, 500, 20
    log_lik = rng.normal(loc=-5, scale=2, size=(n_chains, n_draws, n_obs))
    weight = 1 / (n_chains * n_draws)

    buffer = np.empty(n_obs)
    returned = _logsumexp(log_lik, axis=(0, 1), b=weight, out=buffer)

    assert returned is buffer

    reference = logsumexp(log_lik, axis=(0, 1), b=weight)
    assert_array_almost_equal(buffer, reference, decimal=10)


@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_logsumexp_loo_dtype(rng, dtype):
    """The result dtype must match the input dtype."""
    n_chains, n_draws, n_obs = 4, 500, 10
    draws = rng.normal(loc=-5, scale=2, size=(n_chains, n_draws, n_obs))
    log_lik = draws.astype(dtype)

    result = _logsumexp(log_lik, axis=(0, 1), b=1 / (n_chains * n_draws))

    assert result.dtype == dtype


def test_logsumexp_loo_constant():
    """With a constant input and ``b=1/n`` the result equals that constant."""
    n_chains, n_draws, n_obs = 4, 500, 10
    constant_value = -5.0

    log_lik = np.full((n_chains, n_draws, n_obs), constant_value)
    reduce_kwargs = {"axis": (0, 1), "b": 1 / (n_chains * n_draws)}

    scipy_result = logsumexp(log_lik, **reduce_kwargs)
    arviz_result = _logsumexp(log_lik, **reduce_kwargs)

    expected = np.full(n_obs, constant_value)
    for result in (scipy_result, arviz_result):
        assert_array_almost_equal(result, expected, decimal=10)
    assert_array_almost_equal(scipy_result, arviz_result, decimal=10)


def test_logsumexp_loo_known_values():
    """Compare against a hand-computed log-mean-exp on a 1 x 2 x 2 input."""
    log_lik = np.array([[[0.0, -1.0], [-2.0, -3.0]]])
    n_samples = 2
    reduce_kwargs = {"axis": (0, 1), "b": 1 / n_samples}

    scipy_result = logsumexp(log_lik, **reduce_kwargs)
    arviz_result = _logsumexp(log_lik, **reduce_kwargs)

    # One expected value per column: log of the mean of exp over the two rows.
    expected = np.array(
        [
            np.log((np.exp(first) + np.exp(second)) / n_samples)
            for first, second in zip(log_lik[0, 0], log_lik[0, 1])
        ]
    )

    for result in (scipy_result, arviz_result):
        assert_array_almost_equal(result, expected, decimal=10)
    assert_array_almost_equal(scipy_result, arviz_result, decimal=10)


def test_logsumexp_loo_bounds(rng):
    """The result must lie between ``mean - 5`` and ``max`` of the reduced values."""
    n_chains, n_draws, n_obs = 4, 500, 20
    log_lik = rng.normal(loc=-5, scale=2, size=(n_chains, n_draws, n_obs))
    sample_axes = (0, 1)

    result = _logsumexp(log_lik, axis=sample_axes, b=1 / (n_chains * n_draws))

    upper = log_lik.max(axis=sample_axes)
    lower = log_lik.mean(axis=sample_axes) - 5

    assert np.all(result <= upper)
    assert np.all(result >= lower)