44import numpy as np
55import pytest
66import scipy .stats as stats
7- from aesara import function
7+ from aesara import config , function
88
99from aeppl .dists import dirac_delta
1010from aeppl .logprob import ParameterValueError , icdf , logcdf , logprob
@@ -70,7 +70,10 @@ def scipy_logprob_tester(
7070 else :
7171 raise ValueError (f"test must be one of (logprob, logcdf, icdf), got { test } " )
7272
73- aesara_res_val = aesara_res .eval (dist_params )
73+ aesara_fn = function (
74+ tuple (dist_params .keys ()), aesara_res , on_unused_input = "ignore"
75+ )
76+ aesara_res_val = aesara_fn (* tuple (dist_params .values ()))
7477
7578 numpy_res = np .asarray (test_fn (obs , * dist_params .values ()))
7679
@@ -452,9 +455,14 @@ def scipy_logprob(obs, c):
452455
453456
454457@pytest .mark .parametrize (
455- "dist_params, obs, size, error " ,
458+ "dist_params, obs, size, param_error " ,
456459 [
457- ((- 1 , - 1.0 ), np .array ([- np .pi , - 0.5 , 0 , 1 , np .pi ], dtype = np .float64 ), (), True ),
460+ (
461+ (- 1 , - 1.0 ),
462+ np .array ([- np .pi , - 0.5 , 0 , 1 , np .pi ], dtype = np .float64 ),
463+ (),
464+ True ,
465+ ),
458466 (
459467 (1.5 , 10.5 ),
460468 np .array ([- np .pi , - 0.5 , 0 , 1 , np .pi ], dtype = np .float64 ),
@@ -467,25 +475,36 @@ def scipy_logprob(obs, c):
467475 (2 , 3 ),
468476 False ,
469477 ),
470- ((10 , 1.0 ), np .array ([- np .pi , - 0.5 , 0 , 1 , np .pi ], dtype = np .float64 ), (), False ),
478+ (
479+ (10 , 1.0 ),
480+ np .array ([- np .pi , - 0.5 , 0 , 1 , np .pi ], dtype = np .float64 ),
481+ (),
482+ False ,
483+ ),
471484 ],
472485)
473- def test_vonmises_logprob (dist_params , obs , size , error ):
486+ def test_vonmises_logprob (dist_params , obs , size , param_error ):
474487 dist_params_at , obs_at , size_at = create_aesara_params (dist_params , obs , size )
475488 dist_params = dict (zip (dist_params_at , dist_params ))
476489
477490 x = at .random .vonmises (* dist_params_at , size = size_at )
478491
479- cm = contextlib .suppress () if not error else pytest .raises (ParameterValueError )
492+ param_cm = (
493+ contextlib .suppress () if not param_error else pytest .raises (ParameterValueError )
494+ )
495+ i0_cm = (
496+ contextlib .suppress ()
497+ if not param_error
498+ else pytest .raises (
499+ UserWarning , match = "The Op i0 does not provide a C implementation"
500+ )
501+ )
480502
481503 def scipy_logprob (obs , mu , kappa ):
482504 return stats .vonmises .logpdf (obs , kappa , loc = mu )
483505
484- with pytest .raises (
485- UserWarning , match = "The Op i0 does not provide a C implementation"
486- ):
487- with cm :
488- scipy_logprob_tester (x , obs , dist_params , test_fn = scipy_logprob )
506+ with config .change_flags (on_opt_error = "warn" ), param_cm , i0_cm :
507+ scipy_logprob_tester (x , obs , dist_params , test_fn = scipy_logprob )
489508
490509
491510@pytest .mark .parametrize (
0 commit comments