@@ -333,7 +333,7 @@ def replace_rvs_by_values(
333
333
graphs : Sequence [TensorVariable ],
334
334
* ,
335
335
rvs_to_values : Dict [TensorVariable , TensorVariable ],
336
- rvs_to_transforms : Dict [TensorVariable , RVTransform ],
336
+ rvs_to_transforms : Optional [ Dict [TensorVariable , RVTransform ]] = None ,
337
337
** kwargs ,
338
338
) -> List [TensorVariable ]:
339
339
"""Clone and replace random variables in graphs with their value variables.
@@ -346,7 +346,7 @@ def replace_rvs_by_values(
346
346
The graphs in which to perform the replacements.
347
347
rvs_to_values
348
348
Mapping between the original graph RVs and respective value variables
349
- rvs_to_transforms
349
+ rvs_to_transforms, optional
350
350
Mapping between the original graph RVs and respective value transforms
351
351
"""
352
352
@@ -361,7 +361,8 @@ def replace_rvs_by_values(
361
361
for rv , value in rvs_to_values .items ():
362
362
equiv_rv = equiv .get (rv , rv )
363
363
equiv_rvs_to_values [equiv_rv ] = equiv .get (value , value )
364
- equiv_rvs_to_transforms [equiv_rv ] = rvs_to_transforms [rv ]
364
+ if rvs_to_transforms is not None :
365
+ equiv_rvs_to_transforms [equiv_rv ] = rvs_to_transforms [rv ]
365
366
366
367
def poulate_replacements (rv , replacements ):
367
368
# Populate replacements dict with {rv: value} pairs indicating which graph
@@ -372,14 +373,15 @@ def poulate_replacements(rv, replacements):
372
373
if value is None :
373
374
return []
374
375
375
- transform = equiv_rvs_to_transforms .get (rv , None )
376
- if transform is not None :
377
- # We want to replace uses of the RV by the back-transformation of its value
378
- value = transform .backward (value , * rv .owner .inputs )
379
- # The value may have a less precise type than the rv. In this case
380
- # filter_variable will add a SpecifyShape to ensure they are consistent
381
- value = rv .type .filter_variable (value , allow_convert = True )
382
- value .name = rv .name
376
+ if rvs_to_transforms is not None :
377
+ transform = equiv_rvs_to_transforms .get (rv , None )
378
+ if transform is not None :
379
+ # We want to replace uses of the RV by the back-transformation of its value
380
+ value = transform .backward (value , * rv .owner .inputs )
381
+ # The value may have a less precise type than the rv. In this case
382
+ # filter_variable will add a SpecifyShape to ensure they are consistent
383
+ value = rv .type .filter_variable (value , allow_convert = True )
384
+ value .name = rv .name
383
385
384
386
replacements [rv ] = value
385
387
# Also walk the graph of the value variable to make any additional
0 commit comments