@@ -305,110 +305,6 @@ def test_walk_model():
305
305
assert e in res
306
306
307
307
308
- @pytest .mark .parametrize ("symbolic_rv" , (False , True ))
309
- @pytest .mark .parametrize ("apply_transforms" , (True , False ))
310
- def test_rvs_to_value_vars (symbolic_rv , apply_transforms ):
311
-
312
- # Interval transform between last two arguments
313
- interval = Interval (bounds_fn = lambda * args : (args [- 2 ], args [- 1 ]))
314
-
315
- with pm .Model () as m :
316
- a = pm .Uniform ("a" , 0.0 , 1.0 )
317
- if symbolic_rv :
318
- raw_b = pm .Uniform .dist (0 , a + 1.0 )
319
- b = pm .Censored ("b" , raw_b , lower = 0 , upper = a + 1.0 , transform = interval )
320
- # If not True, another distribution has to be used
321
- assert isinstance (b .owner .op , SymbolicRandomVariable )
322
- else :
323
- b = pm .Uniform ("b" , 0 , a + 1.0 , transform = interval )
324
- c = pm .Normal ("c" )
325
- d = at .log (c + b ) + 2.0
326
-
327
- a_value_var = m .rvs_to_values [a ]
328
- assert a_value_var .tag .transform
329
-
330
- b_value_var = m .rvs_to_values [b ]
331
- c_value_var = m .rvs_to_values [c ]
332
-
333
- (res ,) = rvs_to_value_vars ((d ,), apply_transforms = apply_transforms )
334
-
335
- assert res .owner .op == at .add
336
- log_output = res .owner .inputs [0 ]
337
- assert log_output .owner .op == at .log
338
- log_add_output = res .owner .inputs [0 ].owner .inputs [0 ]
339
- assert log_add_output .owner .op == at .add
340
- c_output = log_add_output .owner .inputs [0 ]
341
-
342
- # We make sure that the random variables were replaced
343
- # with their value variables
344
- assert c_output == c_value_var
345
- b_output = log_add_output .owner .inputs [1 ]
346
- # When transforms are applied, the input is the back-transformation of the value_var,
347
- # otherwise it is the value_var itself
348
- if apply_transforms :
349
- assert b_output != b_value_var
350
- else :
351
- assert b_output == b_value_var
352
-
353
- res_ancestors = list (walk_model ((res ,)))
354
- res_rv_ancestors = [
355
- v for v in res_ancestors if v .owner and isinstance (v .owner .op , RandomVariable )
356
- ]
357
-
358
- # There shouldn't be any `RandomVariable`s in the resulting graph
359
- assert len (res_rv_ancestors ) == 0
360
- assert b_value_var in res_ancestors
361
- assert c_value_var in res_ancestors
362
- # When transforms are used, `d` depends on `a` through the back-transformation of
363
- # `b`, otherwise there is no direct connection between `d` and `a`
364
- if apply_transforms :
365
- assert a_value_var in res_ancestors
366
- else :
367
- assert a_value_var not in res_ancestors
368
-
369
-
370
- def test_rvs_to_value_vars_nested ():
371
- # Test that calling rvs_to_value_vars in models with nested transformations
372
- # does not change the original rvs in place. See issue #5172
373
- with pm .Model () as m :
374
- one = pm .LogNormal ("one" , mu = 0 )
375
- two = pm .LogNormal ("two" , mu = at .log (one ))
376
-
377
- # We add potentials or deterministics that are not in topological order
378
- pm .Potential ("two_pot" , two )
379
- pm .Potential ("one_pot" , one )
380
-
381
- before = aesara .clone_replace (m .free_RVs )
382
-
383
- # This call would change the model free_RVs in place in #5172
384
- res = rvs_to_value_vars (m .potentials , apply_transforms = True )
385
-
386
- after = aesara .clone_replace (m .free_RVs )
387
-
388
- assert equal_computations (before , after )
389
-
390
-
391
- def test_rvs_to_value_vars_unvalued_rv ():
392
- with pm .Model () as m :
393
- x = pm .Normal ("x" )
394
- y = pm .Normal .dist (x )
395
- z = pm .Normal ("z" , y )
396
- out = z + y
397
-
398
- x_value = m .rvs_to_values [x ]
399
- z_value = m .rvs_to_values [z ]
400
-
401
- (res ,) = rvs_to_value_vars ((out ,))
402
-
403
- assert res .owner .op == at .add
404
- assert res .owner .inputs [0 ] is z_value
405
- res_y = res .owner .inputs [1 ]
406
- # Graph should have be cloned, and therefore y and res_y should have different ids
407
- assert res_y is not y
408
- assert res_y .owner .op == at .random .normal
409
- assert res_y .owner .inputs [3 ] is x_value
410
-
411
-
412
308
class TestCompilePyMC :
413
309
def test_check_bounds_flag (self ):
414
310
"""Test that CheckParameterValue Ops are replaced or removed when using compile_pymc"""
@@ -633,3 +529,106 @@ def test_constant_fold_raises():
633
529
634
530
res = constant_fold ((y , y .shape ), raise_not_constant = False )
635
531
assert tuple (res [1 ].eval ()) == (5 ,)
532
+
533
+
534
+ class TestReplaceRVsByValues :
535
+ @pytest .mark .parametrize ("symbolic_rv" , (False , True ))
536
+ @pytest .mark .parametrize ("apply_transforms" , (True , False ))
537
+ def test_basic (self , symbolic_rv , apply_transforms ):
538
+
539
+ # Interval transform between last two arguments
540
+ interval = Interval (bounds_fn = lambda * args : (args [- 2 ], args [- 1 ]))
541
+
542
+ with pm .Model () as m :
543
+ a = pm .Uniform ("a" , 0.0 , 1.0 )
544
+ if symbolic_rv :
545
+ raw_b = pm .Uniform .dist (0 , a + 1.0 )
546
+ b = pm .Censored ("b" , raw_b , lower = 0 , upper = a + 1.0 , transform = interval )
547
+ # If not True, another distribution has to be used
548
+ assert isinstance (b .owner .op , SymbolicRandomVariable )
549
+ else :
550
+ b = pm .Uniform ("b" , 0 , a + 1.0 , transform = interval )
551
+ c = pm .Normal ("c" )
552
+ d = at .log (c + b ) + 2.0
553
+
554
+ a_value_var = m .rvs_to_values [a ]
555
+ assert a_value_var .tag .transform
556
+
557
+ b_value_var = m .rvs_to_values [b ]
558
+ c_value_var = m .rvs_to_values [c ]
559
+
560
+ (res ,) = rvs_to_value_vars ((d ,), apply_transforms = apply_transforms )
561
+
562
+ assert res .owner .op == at .add
563
+ log_output = res .owner .inputs [0 ]
564
+ assert log_output .owner .op == at .log
565
+ log_add_output = res .owner .inputs [0 ].owner .inputs [0 ]
566
+ assert log_add_output .owner .op == at .add
567
+ c_output = log_add_output .owner .inputs [0 ]
568
+
569
+ # We make sure that the random variables were replaced
570
+ # with their value variables
571
+ assert c_output == c_value_var
572
+ b_output = log_add_output .owner .inputs [1 ]
573
+ # When transforms are applied, the input is the back-transformation of the value_var,
574
+ # otherwise it is the value_var itself
575
+ if apply_transforms :
576
+ assert b_output != b_value_var
577
+ else :
578
+ assert b_output == b_value_var
579
+
580
+ res_ancestors = list (walk_model ((res ,)))
581
+ res_rv_ancestors = [
582
+ v for v in res_ancestors if v .owner and isinstance (v .owner .op , RandomVariable )
583
+ ]
584
+
585
+ # There shouldn't be any `RandomVariable`s in the resulting graph
586
+ assert len (res_rv_ancestors ) == 0
587
+ assert b_value_var in res_ancestors
588
+ assert c_value_var in res_ancestors
589
+ # When transforms are used, `d` depends on `a` through the back-transformation of
590
+ # `b`, otherwise there is no direct connection between `d` and `a`
591
+ if apply_transforms :
592
+ assert a_value_var in res_ancestors
593
+ else :
594
+ assert a_value_var not in res_ancestors
595
+
596
+ def test_unvalued_rv (self ):
597
+ with pm .Model () as m :
598
+ x = pm .Normal ("x" )
599
+ y = pm .Normal .dist (x )
600
+ z = pm .Normal ("z" , y )
601
+ out = z + y
602
+
603
+ x_value = m .rvs_to_values [x ]
604
+ z_value = m .rvs_to_values [z ]
605
+
606
+ (res ,) = rvs_to_value_vars ((out ,))
607
+
608
+ assert res .owner .op == at .add
609
+ assert res .owner .inputs [0 ] is z_value
610
+ res_y = res .owner .inputs [1 ]
611
+ # Graph should have be cloned, and therefore y and res_y should have different ids
612
+ assert res_y is not y
613
+ assert res_y .owner .op == at .random .normal
614
+ assert res_y .owner .inputs [3 ] is x_value
615
+
616
+ def test_no_change_inplace (self ):
617
+ # Test that calling rvs_to_value_vars in models with nested transformations
618
+ # does not change the original rvs in place. See issue #5172
619
+ with pm .Model () as m :
620
+ one = pm .LogNormal ("one" , mu = 0 )
621
+ two = pm .LogNormal ("two" , mu = at .log (one ))
622
+
623
+ # We add potentials or deterministics that are not in topological order
624
+ pm .Potential ("two_pot" , two )
625
+ pm .Potential ("one_pot" , one )
626
+
627
+ before = aesara .clone_replace (m .free_RVs )
628
+
629
+ # This call would change the model free_RVs in place in #5172
630
+ res = rvs_to_value_vars (m .potentials , apply_transforms = True )
631
+
632
+ after = aesara .clone_replace (m .free_RVs )
633
+
634
+ assert equal_computations (before , after )
0 commit comments