Skip to content

Commit d0face4

Browse files
committed
Better guesses for why logp has RVs
1 parent 8ebb61e commit d0face4

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

pymc/logprob/basic.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -380,13 +380,6 @@ def icdf(rv: TensorVariable, value: TensorLike, warn_rvs=True, **kwargs) -> Tens
380380
return expr
381381

382382

383-
RVS_IN_JOINT_LOGP_GRAPH_MSG = (
384-
"Random variables detected in the logp graph: %s.\n"
385-
"This can happen when DensityDist logp or Interval transform functions reference nonlocal variables,\n"
386-
"or when not all rvs have a corresponding value variable."
387-
)
388-
389-
390383
def conditional_logp(
391384
rv_values: dict[TensorVariable, TensorVariable],
392385
warn_rvs=True,
@@ -541,7 +534,11 @@ def conditional_logp(
541534
if warn_rvs:
542535
rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logprobs)
543536
if rvs_in_logp_expressions:
544-
warnings.warn(RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions, UserWarning)
537+
warnings.warn(
538+
f"Random variables detected in the logp graph: {rvs_in_logp_expressions}.\n"
539+
"This can happen when not all random variables have a corresponding value variable.",
540+
UserWarning,
541+
)
545542

546543
return values_to_logprobs
547544

@@ -589,6 +586,10 @@ def transformed_conditional_logp(
589586

590587
rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logp_terms_list)
591588
if rvs_in_logp_expressions:
592-
raise ValueError(RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions)
589+
raise ValueError(
590+
f"Random variables detected in the logp graph: {rvs_in_logp_expressions}.\n"
591+
"This can happen when mixing variables from different models, "
592+
"or when CustomDist logp or Interval transform functions reference nonlocal variables."
593+
)
593594

594595
return logp_terms_list

0 commit comments

Comments
 (0)