@@ -380,13 +380,6 @@ def icdf(rv: TensorVariable, value: TensorLike, warn_rvs=True, **kwargs) -> Tens
380380 return expr
381381
382382
383- RVS_IN_JOINT_LOGP_GRAPH_MSG = (
384- "Random variables detected in the logp graph: %s.\n "
385- "This can happen when DensityDist logp or Interval transform functions reference nonlocal variables,\n "
386- "or when not all rvs have a corresponding value variable."
387- )
388-
389-
390383def conditional_logp (
391384 rv_values : dict [TensorVariable , TensorVariable ],
392385 warn_rvs = True ,
@@ -541,7 +534,11 @@ def conditional_logp(
541534 if warn_rvs :
542535 rvs_in_logp_expressions = _find_unallowed_rvs_in_graph (logprobs )
543536 if rvs_in_logp_expressions :
544- warnings .warn (RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions , UserWarning )
537+ warnings .warn (
538+ f"Random variables detected in the logp graph: { rvs_in_logp_expressions } .\n "
539+ "This can happen when not all random variables have a corresponding value variable." ,
540+ UserWarning ,
541+ )
545542
546543 return values_to_logprobs
547544
@@ -589,6 +586,10 @@ def transformed_conditional_logp(
589586
590587 rvs_in_logp_expressions = _find_unallowed_rvs_in_graph (logp_terms_list )
591588 if rvs_in_logp_expressions :
592- raise ValueError (RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions )
589+ raise ValueError (
590+ f"Random variables detected in the logp graph: { rvs_in_logp_expressions } .\n "
591+ "This can happen when mixing variables from different models, "
592+ "or when CustomDist logp or Interval transform functions reference nonlocal variables."
593+ )
593594
594595 return logp_terms_list
0 commit comments