Skip to content

Commit 5ec481f

Browse files
committed
Make rvs_to_transforms optional in replace_rvs_by_values
1 parent 102fdc9 commit 5ec481f

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

pymc/pytensorf.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ def replace_rvs_by_values(
333333
graphs: Sequence[TensorVariable],
334334
*,
335335
rvs_to_values: Dict[TensorVariable, TensorVariable],
336-
rvs_to_transforms: Dict[TensorVariable, RVTransform],
336+
rvs_to_transforms: Optional[Dict[TensorVariable, RVTransform]] = None,
337337
**kwargs,
338338
) -> List[TensorVariable]:
339339
"""Clone and replace random variables in graphs with their value variables.
@@ -346,7 +346,7 @@ def replace_rvs_by_values(
346346
The graphs in which to perform the replacements.
347347
rvs_to_values
348348
Mapping between the original graph RVs and respective value variables
349-
rvs_to_transforms
349+
rvs_to_transforms, optional
350350
Mapping between the original graph RVs and respective value transforms
351351
"""
352352

@@ -361,7 +361,8 @@ def replace_rvs_by_values(
361361
for rv, value in rvs_to_values.items():
362362
equiv_rv = equiv.get(rv, rv)
363363
equiv_rvs_to_values[equiv_rv] = equiv.get(value, value)
364-
equiv_rvs_to_transforms[equiv_rv] = rvs_to_transforms[rv]
364+
if rvs_to_transforms is not None:
365+
equiv_rvs_to_transforms[equiv_rv] = rvs_to_transforms[rv]
365366

366367
def poulate_replacements(rv, replacements):
367368
# Populate replacements dict with {rv: value} pairs indicating which graph
@@ -372,14 +373,15 @@ def poulate_replacements(rv, replacements):
372373
if value is None:
373374
return []
374375

375-
transform = equiv_rvs_to_transforms.get(rv, None)
376-
if transform is not None:
377-
# We want to replace uses of the RV by the back-transformation of its value
378-
value = transform.backward(value, *rv.owner.inputs)
379-
# The value may have a less precise type than the rv. In this case
380-
# filter_variable will add a SpecifyShape to ensure they are consistent
381-
value = rv.type.filter_variable(value, allow_convert=True)
382-
value.name = rv.name
376+
if rvs_to_transforms is not None:
377+
transform = equiv_rvs_to_transforms.get(rv, None)
378+
if transform is not None:
379+
# We want to replace uses of the RV by the back-transformation of its value
380+
value = transform.backward(value, *rv.owner.inputs)
381+
# The value may have a less precise type than the rv. In this case
382+
# filter_variable will add a SpecifyShape to ensure they are consistent
383+
value = rv.type.filter_variable(value, allow_convert=True)
384+
value.name = rv.name
383385

384386
replacements[rv] = value
385387
# Also walk the graph of the value variable to make any additional

0 commit comments

Comments
 (0)