44import pytest
55import scipy as sp
66import scipy .special
7- from aesara .graph .basic import equal_computations
87from aesara .graph .fg import FunctionGraph
98from numdifftools import Jacobian
109
2221 TransformValuesMapping ,
2322 TransformValuesRewrite ,
2423 _default_transformed_rv ,
25- transformed_variable ,
2624)
2725from tests .utils import assert_no_rvs
2826
@@ -176,15 +174,13 @@ def test_transformed_logprob(at_dist, dist_params, sp_dist, size):
176174
177175 a = at_dist (* dist_params , size = size )
178176 a .name = "a"
179- a_value_var = at .tensor (a .dtype , shape = (None ,) * a .ndim )
180- a_value_var .name = "a_value"
181177
182178 b = at .random .normal (a , 1.0 )
183179 b .name = "b"
184180
185- transform_rewrite = TransformValuesRewrite ({a_value_var : DEFAULT_TRANSFORM })
186- res , (b_value_var ,) = joint_logprob (
187- b , realized = { a : a_value_var } , extra_rewrites = transform_rewrite
181+ transform_rewrite = TransformValuesRewrite ({a : DEFAULT_TRANSFORM })
182+ res , (b_value_var , a_value_var ) = joint_logprob (
183+ b , a , extra_rewrites = transform_rewrite
188184 )
189185
190186 test_val_rng = np .random .RandomState (3238 )
@@ -268,12 +264,10 @@ def a_backward_fn_(x):
268264@pytest .mark .parametrize ("use_jacobian" , [True , False ])
269265def test_simple_transformed_logprob_nojac (use_jacobian ):
270266 X_rv = at .random .halfnormal (0 , 3 , name = "X" )
271- x_vv = X_rv .clone ()
272- x_vv .name = "x"
273267
274- transform_rewrite = TransformValuesRewrite ({x_vv : DEFAULT_TRANSFORM })
275- tr_logp , _ = joint_logprob (
276- realized = { X_rv : x_vv } ,
268+ transform_rewrite = TransformValuesRewrite ({X_rv : DEFAULT_TRANSFORM })
269+ tr_logp , ( x_vv ,) = joint_logprob (
270+ X_rv ,
277271 extra_rewrites = transform_rewrite ,
278272 use_jacobian = use_jacobian ,
279273 )
@@ -321,19 +315,17 @@ def test_hierarchical_uniform_transform():
321315 upper_rv = at .random .uniform (9 , 10 , name = "upper" )
322316 x_rv = at .random .uniform (lower_rv , upper_rv , name = "x" )
323317
324- lower = lower_rv .clone ()
325- upper = upper_rv .clone ()
326- x = x_rv .clone ()
327-
328318 transform_rewrite = TransformValuesRewrite (
329319 {
330- lower : DEFAULT_TRANSFORM ,
331- upper : DEFAULT_TRANSFORM ,
332- x : DEFAULT_TRANSFORM ,
320+ lower_rv : DEFAULT_TRANSFORM ,
321+ upper_rv : DEFAULT_TRANSFORM ,
322+ x_rv : DEFAULT_TRANSFORM ,
333323 }
334324 )
335- logp , _ = joint_logprob (
336- realized = {lower_rv : lower , upper_rv : upper , x_rv : x },
325+ logp , (lower , upper , x ) = joint_logprob (
326+ lower_rv ,
327+ upper_rv ,
328+ x_rv ,
337329 extra_rewrites = transform_rewrite ,
338330 )
339331
@@ -346,20 +338,18 @@ def test_nondefault_transforms():
346338 scale_rv = at .random .uniform (- 1 , 1 , name = "scale" )
347339 x_rv = at .random .normal (loc_rv , scale_rv , name = "x" )
348340
349- loc = loc_rv .clone ()
350- scale = scale_rv .clone ()
351- x = x_rv .clone ()
352-
353341 transform_rewrite = TransformValuesRewrite (
354342 {
355- loc : None ,
356- scale : LogOddsTransform (),
357- x : LogTransform (),
343+ loc_rv : None ,
344+ scale_rv : LogOddsTransform (),
345+ x_rv : LogTransform (),
358346 }
359347 )
360348
361- logp , _ = joint_logprob (
362- realized = {loc_rv : loc , scale_rv : scale , x_rv : x },
349+ logp , (loc , scale , x ) = joint_logprob (
350+ loc_rv ,
351+ scale_rv ,
352+ x_rv ,
363353 extra_rewrites = transform_rewrite ,
364354 )
365355
@@ -391,12 +381,11 @@ def test_default_transform_multiout():
391381 # multiple outputs and no default output.
392382 sd = at .linalg .svd (at .eye (1 ))[1 ][0 ]
393383 x_rv = at .random .normal (0 , sd , name = "x" )
394- x = x_rv .clone ()
395384
396- transform_rewrite = TransformValuesRewrite ({x : DEFAULT_TRANSFORM })
385+ transform_rewrite = TransformValuesRewrite ({x_rv : DEFAULT_TRANSFORM })
397386
398- logp , _ = joint_logprob (
399- realized = { x_rv : x } ,
387+ logp , ( x ,) = joint_logprob (
388+ x_rv ,
400389 extra_rewrites = transform_rewrite ,
401390 )
402391
@@ -412,12 +401,11 @@ def test_nonexistent_default_transform():
412401 transform does not fail
413402 """
414403 x_rv = at .random .normal (name = "x" )
415- x = x_rv .clone ()
416404
417- transform_rewrite = TransformValuesRewrite ({x : DEFAULT_TRANSFORM })
405+ transform_rewrite = TransformValuesRewrite ({x_rv : DEFAULT_TRANSFORM })
418406
419- logp , _ = joint_logprob (
420- realized = { x_rv : x } ,
407+ logp , ( x ,) = joint_logprob (
408+ x_rv ,
421409 extra_rewrites = transform_rewrite ,
422410 )
423411
@@ -446,9 +434,8 @@ def test_original_values_output_dict():
446434 the logprob factor
447435 """
448436 p_rv = at .random .beta (1 , 1 , name = "p" )
449- p_vv = p_rv .clone ()
450437
451- tr = TransformValuesRewrite ({p_vv : DEFAULT_TRANSFORM })
438+ tr = TransformValuesRewrite ({p_rv : DEFAULT_TRANSFORM })
452439 logp_dict , _ = conditional_logprob (p_rv , extra_rewrites = tr )
453440
454441 assert p_rv in logp_dict
@@ -469,29 +456,28 @@ def test_mixture_transform():
469456 Y_rv = at .stack ([Y_1_rv , Y_2_rv ])[I_rv ]
470457 Y_rv .name = "Y"
471458
472- logp_no_trans , (y_vv , i_vv ) = joint_logprob (Y_rv , I_rv )
459+ logp , (y_vv , i_vv ) = joint_logprob (
460+ Y_rv ,
461+ I_rv ,
462+ )
473463
474- transform_rewrite = TransformValuesRewrite ({y_vv : LogTransform ()})
464+ transform_rewrite = TransformValuesRewrite ({Y_rv : LogOddsTransform ()})
475465
476466 with pytest .warns (None ) as record :
477467 # This shouldn't raise any warnings
478- logp_trans , _ = joint_logprob (
479- realized = {Y_rv : y_vv , I_rv : i_vv },
468+ logp_trans , (y_vv_trans , i_vv_trans ) = joint_logprob (
469+ Y_rv ,
470+ I_rv ,
480471 extra_rewrites = transform_rewrite ,
481472 use_jacobian = False ,
482473 )
483474
484475 assert not record .list
485476
486- # The untransformed graph should be the same as the transformed graph after
487- # replacing the `Y_rv` value variable with a transformed version of itself
488- logp_nt_fg = FunctionGraph (outputs = [logp_no_trans ], clone = False )
489- y_trans = transformed_variable (at .exp (y_vv ), y_vv )
490- y_trans .name = "y_log"
491- logp_nt_fg .replace (y_vv , y_trans )
492- logp_nt = logp_nt_fg .outputs [0 ]
493-
494- assert equal_computations ([logp_nt ], [logp_trans ])
477+ logp_fn = aesara .function ((i_vv , y_vv ), logp )
478+ logp_trans_fn = aesara .function ((i_vv_trans , y_vv_trans ), logp_trans )
479+ np .isclose (logp_trans_fn (0 , np .log (0.1 / 0.9 )), logp_fn (0 , 0.1 ))
480+ np .isclose (logp_trans_fn (1 , np .log (0.1 / 0.9 )), logp_fn (1 , 0.1 ))
495481
496482
497483def test_invalid_interval_transform ():
@@ -642,11 +628,10 @@ def test_scale_transform_rv(rv_size, scale_type):
642628def test_transformed_rv_and_value ():
643629 y_rv = at .random .halfnormal (- 1 , 1 , name = "base_rv" ) + 1
644630 y_rv .name = "y"
645- y_vv = y_rv .clone ()
646631
647- transform_rewrite = TransformValuesRewrite ({y_vv : LogTransform ()})
632+ transform_rewrite = TransformValuesRewrite ({y_rv : LogTransform ()})
648633
649- logp , _ = joint_logprob (realized = { y_rv : y_vv } , extra_rewrites = transform_rewrite )
634+ logp , ( y_vv ,) = joint_logprob (y_rv , extra_rewrites = transform_rewrite )
650635 assert_no_rvs (logp )
651636 logp_fn = aesara .function ([y_vv ], logp )
652637
0 commit comments