Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 207 additions & 0 deletions tests/base/test_stats_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,210 @@ def test_logsumexp_edge_b():
ary = np.array([1.0, 2.0, 3.0])
assert _logsumexp(ary, b=0) == -np.inf
assert _logsumexp(ary, b_inv=0) == np.inf


@pytest.mark.parametrize(
"shape,axis,expected_shape",
[
((4, 500, 20), (0, 1), (20,)),
((4, 500, 1), (0, 1), (1,)),
((4, 500, 8, 10), (0, 1), (8, 10)),
((4, 500, 20), (-3, -2), (20,)),
],
)
def test_logsumexp_loo_shapes(rng, shape, axis, expected_shape):
log_lik = rng.normal(loc=-5, scale=2, size=shape)
n_samples = shape[0] * shape[1]

scipy_result = logsumexp(log_lik, axis=axis, b=1 / n_samples)
arviz_result = _logsumexp(log_lik, axis=axis, b=1 / n_samples)

assert scipy_result.shape == expected_shape
assert arviz_result.shape == expected_shape
assert_array_almost_equal(scipy_result, arviz_result, decimal=10)

max_log_lik = np.max(log_lik, axis=axis)
assert np.all(scipy_result <= max_log_lik + 1)
assert np.all(scipy_result >= -10)
assert np.all(arviz_result <= max_log_lik + 1)
assert np.all(arviz_result >= -10)


def test_logsumexp_loo_weights(rng):
n_chains, n_draws, n_obs = 4, 500, 20
log_lik = rng.normal(loc=-5, scale=2, size=(n_chains, n_draws, n_obs))

log_weights = rng.normal(loc=-5, scale=1, size=(n_chains, n_draws, n_obs))
log_weights = log_weights - logsumexp(log_weights, axis=(0, 1), keepdims=True)

log_weighted = log_weights + log_lik

scipy_elpd_i = logsumexp(log_weighted, axis=(0, 1))
arviz_elpd_i = _logsumexp(log_weighted, axis=(0, 1))

assert scipy_elpd_i.shape == (n_obs,)
assert_array_almost_equal(scipy_elpd_i, arviz_elpd_i, decimal=10)

assert np.all(np.isfinite(scipy_elpd_i))
assert np.all(np.isfinite(arviz_elpd_i))
weights_sum = logsumexp(log_weights, axis=(0, 1))
assert_array_almost_equal(weights_sum, np.zeros(n_obs), decimal=10)


@pytest.mark.parametrize(
"loc,scale",
[
(-0.9, 0.1),
(-3.5, 1.5),
(-20, 5),
(-500, 50),
(-0.1, 0.01),
],
)
def test_logsumexp_loo_values(rng, loc, scale):
log_lik = rng.normal(loc=loc, scale=scale, size=(4, 500, 10))
n_samples = log_lik.shape[0] * log_lik.shape[1]

scipy_result = logsumexp(log_lik, axis=(0, 1), b=1 / n_samples)
arviz_result = _logsumexp(log_lik, axis=(0, 1), b=1 / n_samples)

assert np.all(np.isfinite(scipy_result))
assert np.all(np.isfinite(arviz_result))
assert_array_almost_equal(scipy_result, arviz_result, decimal=8)

assert np.all(np.abs(scipy_result - loc) < scale * 5)
assert np.all(np.abs(arviz_result - loc) < scale * 5)


def test_logsumexp_loo_b_inv(rng):
n_chains, n_draws, n_obs = 4, 500, 20
log_lik = rng.normal(loc=-5, scale=2, size=(n_chains, n_draws, n_obs))
n_samples = n_chains * n_draws

scipy_result = logsumexp(log_lik, axis=(0, 1), b=1 / n_samples)
arviz_result = _logsumexp(log_lik, axis=(0, 1), b_inv=n_samples)

assert_array_almost_equal(scipy_result, arviz_result, decimal=10)


@pytest.mark.parametrize(
"n_chains,n_draws,n_obs",
[
(1, 1000, 10),
(2, 500, 10),
(8, 250, 10),
(4, 2000, 5),
],
)
def test_logsumexp_loo_varying_dims(rng, n_chains, n_draws, n_obs):
log_lik = rng.normal(loc=-5, scale=2, size=(n_chains, n_draws, n_obs))
n_samples = n_chains * n_draws

scipy_result = logsumexp(log_lik, axis=(0, 1), b=1 / n_samples)
arviz_result = _logsumexp(log_lik, axis=(0, 1), b=1 / n_samples)

assert_array_almost_equal(scipy_result, arviz_result, decimal=10)


@pytest.mark.parametrize("keepdims", [True, False])
def test_logsumexp_loo_keepdims(rng, keepdims):
n_chains, n_draws, n_obs = 4, 500, 20
log_lik = rng.normal(loc=-5, scale=2, size=(n_chains, n_draws, n_obs))
n_samples = n_chains * n_draws

scipy_result = logsumexp(log_lik, axis=(0, 1), b=1 / n_samples, keepdims=keepdims)
arviz_result = _logsumexp(log_lik, axis=(0, 1), b=1 / n_samples, keepdims=keepdims)

expected_shape = (1, 1, n_obs) if keepdims else (n_obs,)
assert scipy_result.shape == expected_shape
assert arviz_result.shape == expected_shape
assert_array_almost_equal(scipy_result, arviz_result, decimal=10)


def test_logsumexp_loo_copy(rng):
n_chains, n_draws, n_obs = 4, 500, 20
log_lik = rng.normal(loc=-5, scale=2, size=(n_chains, n_draws, n_obs))
n_samples = n_chains * n_draws

result_copy = _logsumexp(log_lik.copy(), axis=(0, 1), b=1 / n_samples, copy=True)

log_lik_no_copy = log_lik.copy()
result_no_copy = _logsumexp(log_lik_no_copy, axis=(0, 1), b=1 / n_samples, copy=False)

assert_array_almost_equal(result_copy, result_no_copy, decimal=10)

log_lik_test = log_lik.copy()
_ = _logsumexp(log_lik_test, axis=(0, 1), b=1 / n_samples, copy=True)
assert_array_almost_equal(log_lik, log_lik_test)


def test_logsumexp_loo_out(rng):
n_chains, n_draws, n_obs = 4, 500, 20
log_lik = rng.normal(loc=-5, scale=2, size=(n_chains, n_draws, n_obs))
n_samples = n_chains * n_draws

out = np.empty(n_obs)
result = _logsumexp(log_lik, axis=(0, 1), b=1 / n_samples, out=out)

assert result is out

scipy_result = logsumexp(log_lik, axis=(0, 1), b=1 / n_samples)
assert_array_almost_equal(out, scipy_result, decimal=10)


@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_logsumexp_loo_dtype(rng, dtype):
n_chains, n_draws, n_obs = 4, 500, 10
log_lik = rng.normal(loc=-5, scale=2, size=(n_chains, n_draws, n_obs)).astype(dtype)
n_samples = n_chains * n_draws

result = _logsumexp(log_lik, axis=(0, 1), b=1 / n_samples)

assert result.dtype == dtype


def test_logsumexp_loo_constant():
n_chains, n_draws, n_obs = 4, 500, 10

constant_value = -5.0
log_lik = np.full((n_chains, n_draws, n_obs), constant_value)
n_samples = n_chains * n_draws

scipy_result = logsumexp(log_lik, axis=(0, 1), b=1 / n_samples)
arviz_result = _logsumexp(log_lik, axis=(0, 1), b=1 / n_samples)

expected = np.full(n_obs, constant_value)

assert_array_almost_equal(scipy_result, expected, decimal=10)
assert_array_almost_equal(arviz_result, expected, decimal=10)
assert_array_almost_equal(scipy_result, arviz_result, decimal=10)


def test_logsumexp_loo_known_values():
log_lik = np.array([[[0.0, -1.0], [-2.0, -3.0]]])
n_samples = 2

scipy_result = logsumexp(log_lik, axis=(0, 1), b=1 / n_samples)
arviz_result = _logsumexp(log_lik, axis=(0, 1), b=1 / n_samples)

expected_0 = np.log((np.exp(0.0) + np.exp(-2.0)) / 2)
expected_1 = np.log((np.exp(-1.0) + np.exp(-3.0)) / 2)
expected = np.array([expected_0, expected_1])

assert_array_almost_equal(scipy_result, expected, decimal=10)
assert_array_almost_equal(arviz_result, expected, decimal=10)
assert_array_almost_equal(scipy_result, arviz_result, decimal=10)


def test_logsumexp_loo_bounds(rng):
n_chains, n_draws, n_obs = 4, 500, 20
log_lik = rng.normal(loc=-5, scale=2, size=(n_chains, n_draws, n_obs))
n_samples = n_chains * n_draws

result = _logsumexp(log_lik, axis=(0, 1), b=1 / n_samples)

max_vals = np.max(log_lik, axis=(0, 1))
mean_vals = np.mean(log_lik, axis=(0, 1))

assert np.all(result <= max_vals)
assert np.all(result >= mean_vals - 5)