@@ -380,13 +380,6 @@ def icdf(rv: TensorVariable, value: TensorLike, warn_rvs=True, **kwargs) -> Tens
380
380
return expr
381
381
382
382
383
- RVS_IN_JOINT_LOGP_GRAPH_MSG = (
384
- "Random variables detected in the logp graph: %s.\n "
385
- "This can happen when DensityDist logp or Interval transform functions reference nonlocal variables,\n "
386
- "or when not all rvs have a corresponding value variable."
387
- )
388
-
389
-
390
383
def conditional_logp (
391
384
rv_values : dict [TensorVariable , TensorVariable ],
392
385
warn_rvs = True ,
@@ -541,7 +534,11 @@ def conditional_logp(
541
534
if warn_rvs :
542
535
rvs_in_logp_expressions = _find_unallowed_rvs_in_graph (logprobs )
543
536
if rvs_in_logp_expressions :
544
- warnings .warn (RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions , UserWarning )
537
+ warnings .warn (
538
+ f"Random variables detected in the logp graph: { rvs_in_logp_expressions } .\n "
539
+ "This can happen when not all random variables have a corresponding value variable." ,
540
+ UserWarning ,
541
+ )
545
542
546
543
return values_to_logprobs
547
544
@@ -589,6 +586,10 @@ def transformed_conditional_logp(
589
586
590
587
rvs_in_logp_expressions = _find_unallowed_rvs_in_graph (logp_terms_list )
591
588
if rvs_in_logp_expressions :
592
- raise ValueError (RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions )
589
+ raise ValueError (
590
+ f"Random variables detected in the logp graph: { rvs_in_logp_expressions } .\n "
591
+ "This can happen when mixing variables from different models, "
592
+ "or when CustomDist logp or Interval transform functions reference nonlocal variables."
593
+ )
593
594
594
595
return logp_terms_list
0 commit comments