Hey team — following up on pymc-devs/pymc#6716.
The main idea is to allow users to derive a quick causal-style DAG from a PyMC model’s Graphviz diagram (the DPG, i.e., the directed probabilistic/model graph from pm.model_to_graphviz). This is feasible under clear assumptions and with explicit limitations.
Different regression specifications can correspond to the same underlying causal DAG. For example, in a simple chain A → B → C, depending on the estimand one might fit C ~ B + ε or C ~ A + ε. Both are compatible with the same causal story, yet pm.model_to_graphviz will (rightly) produce different model graphs for those different PyMC specifications.
What this is (and isn’t)
- Is: a visualization helper that maps a PyMC model graph to a compact, causal-style DAG by:
  - keeping only pm.Data nodes,
  - optionally showing selected unobserved modeled effects (e.g., priors) as dashed ellipses,
  - drawing edges to the first data node(s) encountered along any directed path (don’t traverse past a data node).
- Is not: causal discovery or identification. It does not infer causal structure; it presents a causal-style view implied by a specific PyMC model.
Using @drbenvincent’s blog post “Causal inference: have you been doing science wrong all this time?” as inspiration, I'll add some examples based on the following DAG, with edges Q → X, Q → Y, X → Y, X → P and Y → P.

Example 1: Adjustment model Y ~ X + Q
Suppose we care about the effect of X on Y. Given the causal DAG, we must adjust for the confounder Q:

    import pymc as pm

    # df is assumed to be a pandas DataFrame with columns Q, X and Y
    with pm.Model() as x_y_model:
        _Q = pm.Data("_Q", df["Q"])
        _X = pm.Data("_X", df["X"])
        _Y = pm.Data("_Y", df["Y"])
        beta_q = pm.Normal("beta_q")
        beta_x = pm.Normal("beta_x")
        sigma_y = pm.HalfNormal("sigma")
        Y = pm.Normal("Y", mu=beta_x * _X + beta_q * _Q, sigma=sigma_y, observed=_Y)

    # Pass the DOT source of the model graph to the proposed helper
    causal_src = pymc_dpg_to_causal_dag(pm.model_to_graphviz(x_y_model).source)

Result (causal-style): _X → _Y and _Q → _Y, exactly what we expect for the adjustment set. 😕
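
The call above hands the DOT source string to the helper. As a rough sketch of how edges could be read from such a string, the snippet below extracts `a -> b` pairs with a regular expression. The DOT text is hand-written and simplified, and `parse_dot_edges` is a hypothetical name, not part of PyMC or of the proposed helper. Real output from pm.model_to_graphviz also contains labels, plates (subgraphs) and possibly quoted names, which this pattern does not handle.

```python
import re

# Matches simple edge statements such as `beta_x -> Y`
# (unquoted names made of word characters only; an assumption of this sketch).
EDGE_RE = re.compile(r"^[ \t]*(\w+)[ \t]*->[ \t]*(\w+)", re.MULTILINE)


def parse_dot_edges(dot_source):
    """Return (parent, child) pairs for every simple edge line in a DOT string."""
    return EDGE_RE.findall(dot_source)


# Hand-written, simplified stand-in for part of the Example 1 model graph.
example_dot = """digraph {
    beta_x [shape=ellipse]
    _X [shape=box]
    _Y [shape=box]
    beta_x -> Y
    _X -> Y
    Y -> _Y
}"""

edges = parse_dot_edges(example_dot)
print(edges)
```

A production version would more likely rely on a proper DOT parser than on a regular expression, since node names and attributes can contain characters this pattern ignores.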

Note
If we want to capture other effects, such as the effect of Q on Y or of Q on P, we could fit further regressions with other adjustment sets, and this function would produce a different causal-style DAG for each one (each DPG maps to one causal-style DAG).
Example 2: Fully specified “luxury” model
If we know the true causal DAG and build a fully specified PyMC model that mirrors it, the helper reproduces that DAG:

    with pm.Model() as full_luxury_model:
        _Q = pm.Data("_Q", df["Q"])
        _X = pm.Data("_X", df["X"])
        _Y = pm.Data("_Y", df["Y"])
        _P = pm.Data("_P", df["P"])

        # slopes
        qx = pm.Normal("qx")  # X ~ Q
        xy = pm.Normal("xy")  # Y ~ X
        qy = pm.Normal("qy")  # Y ~ Q
        xp = pm.Normal("xp")  # P ~ X
        yp = pm.Normal("yp")  # P ~ Y

        # scales
        sigma_x = pm.HalfNormal("sigma_x")
        sigma_y = pm.HalfNormal("sigma_y")
        sigma_p = pm.HalfNormal("sigma_p")

        Q = pm.Normal("Q", observed=_Q)
        X = pm.Normal("X", mu=qx * _Q, sigma=sigma_x, observed=_X)
        Y = pm.Normal("Y", mu=xy * _X + qy * _Q, sigma=sigma_y, observed=_Y)
        P = pm.Normal("P", mu=xp * _X + yp * _Y, sigma=sigma_p, observed=_P)

    causal_src = pymc_dpg_to_causal_dag(pm.model_to_graphviz(full_luxury_model).source)

Output: The original causal diagram 🔥

Proposed API

    def pymc_dpg_to_causal_dag(
        model_or_dot,                 # pm.Model | graphviz.Digraph | DOT string
        *,
        first_hit_only=True,          # connect to closest downstream Data
        node_style='style="filled"',  # Data node style
        unobserved_vars=None,         # e.g. ["intercept", "sigma"]; dashed ellipses
    ) -> graphviz.Source:             # render with .render() or display in notebooks
        ...

- First-hit rule: From each source node, walk forward; when you hit any pm.Data node, draw an edge and stop (don’t traverse beyond that data node). This avoids dense transitive edges and respects mediators (e.g., _Q → _Y, not _Q → _P via _Y).
- Unobserved variables: If provided, render the listed node IDs (that exist in the graph) as dashed ellipses and connect them to their closest downstream data nodes.
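
To make the first-hit rule concrete, here is a minimal sketch on a plain edge list. It is not the draft implementation: `first_hit_edges` is a hypothetical name, and the edge list is a hand-written stand-in for the Example 1 model graph, assuming the graph draws an edge from each observed variable to its data node (Y → _Y).

```python
from collections import defaultdict


def first_hit_edges(edges, data_nodes, sources=None):
    """Connect each source to the first data node(s) reached on any directed path."""
    children = defaultdict(list)
    for parent, child in edges:
        children[parent].append(child)

    data_nodes = set(data_nodes)
    # By default every data node is a source; extra (unobserved) nodes can be passed in.
    sources = data_nodes if sources is None else set(sources)

    result = set()
    for source in sources:
        stack = list(children[source])
        seen = set()
        while stack:
            node = stack.pop()
            if node in seen:
                continue
            seen.add(node)
            if node in data_nodes:
                if node != source:
                    result.add((source, node))
                continue  # first hit: do not traverse past a data node
            stack.extend(children[node])
    return sorted(result)


# Hand-written stand-in for the Example 1 model graph (an assumption of this sketch).
model_edges = [
    ("_X", "Y"), ("_Q", "Y"),
    ("beta_x", "Y"), ("beta_q", "Y"), ("sigma", "Y"),
    ("Y", "_Y"),
]
data = {"_X", "_Q", "_Y"}

causal_edges = first_hit_edges(model_edges, data)
with_prior = first_hit_edges(model_edges, data, sources=data | {"beta_x"})
print(causal_edges)
print(with_prior)
```

Passing extra names through `sources` mirrors the idea behind `unobserved_vars`: the listed nodes become additional starting points and are connected to their closest downstream data nodes.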
Note
We could add this to pymc-extras first, see how the community uses it, and then decide whether it belongs in the main PyMC repo.
Limitations & guidance
- This helper does not discover causal structure or validate identification.
- Multiple PyMC model specifications that implement valid adjustment sets can render to the same or to different causal-style DAGs; that’s intended. (Each DPG maps to one causal-style DAG.)
- Best used when the causal DAG is known and you want to check that the PyMC model mirrors it, or to communicate the causal story more cleanly than the full probabilistic graph does.
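
One way to run the "does the model mirror the DAG" check is to compare edge sets. The sketch below assumes the derived edges are already available as (parent, child) pairs; `compare_edges` is a hypothetical name, and both edge lists are written by hand from the examples in this issue rather than produced by the helper.

```python
def compare_edges(expected, derived):
    """Report edges present in only one of the two graphs."""
    expected, derived = set(expected), set(derived)
    return {
        "missing": sorted(expected - derived),     # in the known DAG, not in the model view
        "unexpected": sorted(derived - expected),  # in the model view, not in the known DAG
    }


# Known causal DAG, written with the data-node names used in the examples.
known_dag = [("_Q", "_X"), ("_Q", "_Y"), ("_X", "_Y"), ("_X", "_P"), ("_Y", "_P")]

# Causal-style edges reported for the Example 1 adjustment model.
adjustment_view = [("_X", "_Y"), ("_Q", "_Y")]

report = compare_edges(known_dag, adjustment_view)
print(report)
```

A non-empty "missing" list is not necessarily an error: an adjustment model such as Example 1 covers only part of the DAG by design, whereas a fully specified model would be expected to leave both lists empty.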
If there’s interest, I’m happy to open a PR adding this as a documented recipe (with tests) or a small utility in an examples module. Check my draft in Google Colab.
cc: @drbenvincent @jessegrabowski @ricardoV94 @cluhmann @twiecki