Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit 5f5aaa4

Browse files
Explicitly compile functions in scipy_logprob_tester and update test_vonmises_logprob
1 parent a170f77 commit 5f5aaa4

File tree

1 file changed

+31
-12
lines changed

1 file changed

+31
-12
lines changed

tests/test_logprob.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55
import pytest
66
import scipy.stats as stats
7-
from aesara import function
7+
from aesara import config, function
88

99
from aeppl.dists import dirac_delta
1010
from aeppl.logprob import ParameterValueError, icdf, logcdf, logprob
@@ -70,7 +70,10 @@ def scipy_logprob_tester(
7070
else:
7171
raise ValueError(f"test must be one of (logprob, logcdf, icdf), got {test}")
7272

73-
aesara_res_val = aesara_res.eval(dist_params)
73+
aesara_fn = function(
74+
tuple(dist_params.keys()), aesara_res, on_unused_input="ignore"
75+
)
76+
aesara_res_val = aesara_fn(*tuple(dist_params.values()))
7477

7578
numpy_res = np.asarray(test_fn(obs, *dist_params.values()))
7679

@@ -452,9 +455,14 @@ def scipy_logprob(obs, c):
452455

453456

454457
@pytest.mark.parametrize(
455-
"dist_params, obs, size, error",
458+
"dist_params, obs, size, param_error",
456459
[
457-
((-1, -1.0), np.array([-np.pi, -0.5, 0, 1, np.pi], dtype=np.float64), (), True),
460+
(
461+
(-1, -1.0),
462+
np.array([-np.pi, -0.5, 0, 1, np.pi], dtype=np.float64),
463+
(),
464+
True,
465+
),
458466
(
459467
(1.5, 10.5),
460468
np.array([-np.pi, -0.5, 0, 1, np.pi], dtype=np.float64),
@@ -467,25 +475,36 @@ def scipy_logprob(obs, c):
467475
(2, 3),
468476
False,
469477
),
470-
((10, 1.0), np.array([-np.pi, -0.5, 0, 1, np.pi], dtype=np.float64), (), False),
478+
(
479+
(10, 1.0),
480+
np.array([-np.pi, -0.5, 0, 1, np.pi], dtype=np.float64),
481+
(),
482+
False,
483+
),
471484
],
472485
)
473-
def test_vonmises_logprob(dist_params, obs, size, error):
486+
def test_vonmises_logprob(dist_params, obs, size, param_error):
474487
dist_params_at, obs_at, size_at = create_aesara_params(dist_params, obs, size)
475488
dist_params = dict(zip(dist_params_at, dist_params))
476489

477490
x = at.random.vonmises(*dist_params_at, size=size_at)
478491

479-
cm = contextlib.suppress() if not error else pytest.raises(ParameterValueError)
492+
param_cm = (
493+
contextlib.suppress() if not param_error else pytest.raises(ParameterValueError)
494+
)
495+
i0_cm = (
496+
contextlib.suppress()
497+
if not param_error
498+
else pytest.raises(
499+
UserWarning, match="The Op i0 does not provide a C implementation"
500+
)
501+
)
480502

481503
def scipy_logprob(obs, mu, kappa):
482504
return stats.vonmises.logpdf(obs, kappa, loc=mu)
483505

484-
with pytest.raises(
485-
UserWarning, match="The Op i0 does not provide a C implementation"
486-
):
487-
with cm:
488-
scipy_logprob_tester(x, obs, dist_params, test_fn=scipy_logprob)
506+
with config.change_flags(on_opt_error="warn"), param_cm, i0_cm:
507+
scipy_logprob_tester(x, obs, dist_params, test_fn=scipy_logprob)
489508

490509

491510
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)