
Creating function to transform Graphviz DPG model into Causal DAG #572

@cetagostini

Hey team — following up on pymc-devs/pymc#6716.

The main idea is to allow users to derive a quick causal-style DAG from a PyMC model’s Graphviz diagram (the DPG, i.e., the directed probabilistic/model graph from pm.model_to_graphviz). This is feasible under clear assumptions and with explicit limitations.

Different regression specifications can correspond to the same underlying causal DAG. For example, in a simple chain A → B → C, depending on the estimand one might fit C ~ B + ε or C ~ A + ε. Both are compatible with the same causal story, yet pm.model_to_graphviz will (rightly) produce different model graphs for those different PyMC specifications.

What this is (and isn’t)

  • Is: a visualization helper that maps a PyMC model graph to a compact, causal-style DAG by:

    • keeping only pm.Data nodes,
    • optionally showing selected unobserved modeled effects (e.g., priors) as dashed ellipses,
    • drawing edges to the first data node(s) encountered along any directed path (don’t traverse past a data node).
  • Is not: causal discovery or identification. It does not infer causal structure; it presents a causal-style view implied by a specific PyMC model.
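The sketch below shows the first-hit mapping on a plain parent→children dict rather than on PyMC's real DOT output. The function name `causal_style_edges` and the hand-written graphs are hypothetical. The graphs assume that an observed variable points to its `pm.Data` container (e.g. `C → _C`). It reuses the chain A → B → C from above. The two single regressions, C ~ B and C ~ A, give different causal-style edges. The fully chained model shows that stopping at `_B` avoids the transitive edge `_A → _C`.

```python
from collections import deque


def causal_style_edges(children, data_nodes):
    """Link each Data node to the first Data node(s) met on every directed path."""
    edges = set()
    for src in data_nodes:
        seen, queue = {src}, deque(children.get(src, []))
        while queue:
            node = queue.popleft()
            if node in seen:
                continue
            seen.add(node)
            if node in data_nodes:
                edges.add((src, node))  # first hit: record the edge, stop this path
            else:
                queue.extend(children.get(node, []))
    return edges


# Hypothetical model graphs (parent -> children) for the chain A -> B -> C.
# Assumption: an observed variable points to its Data node (e.g. C -> _C).
data = {"_A", "_B", "_C"}
c_on_b = {"_B": ["C"], "beta": ["C"], "C": ["_C"]}                  # C ~ B + eps
c_on_a = {"_A": ["C"], "beta": ["C"], "C": ["_C"]}                  # C ~ A + eps
full_chain = {"_A": ["B"], "B": ["_B"], "_B": ["C"], "C": ["_C"]}   # B ~ A, C ~ B

print(sorted(causal_style_edges(c_on_b, data)))      # [('_B', '_C')]
print(sorted(causal_style_edges(c_on_a, data)))      # [('_A', '_C')]
print(sorted(causal_style_edges(full_chain, data)))  # [('_A', '_B'), ('_B', '_C')]
```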

Taking @drbenvincent’s blog post “Causal inference: have you been doing science wrong all this time?” as inspiration, here are some examples based on the following DAG.

[Image: the example causal DAG over Q, X, Y, and P]
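The examples read from a `df` with columns Q, X, Y, and P, but the issue does not define it. One way to get such data is to simulate linear-Gaussian data with NumPy. This sketch assumes the DAG is Q → X, Q → Y, X → Y, X → P, Y → P, which is the structure the full model in Example 2 encodes. The coefficients, noise scales, and sample size are arbitrary illustrative choices, not values from the issue.

```python
import numpy as np

rng = np.random.default_rng(seed=42)
n = 500

# Structural equations following the assumed DAG Q -> X, Q -> Y, X -> Y, X -> P, Y -> P.
# All coefficients below are arbitrary illustrative values.
Q = rng.normal(size=n)
X = 0.8 * Q + rng.normal(scale=0.5, size=n)
Y = 1.5 * X - 0.7 * Q + rng.normal(scale=0.5, size=n)
P = 0.4 * X + 0.9 * Y + rng.normal(scale=0.5, size=n)

# A plain dict supports the df["Q"]-style lookups used in the examples;
# pandas.DataFrame(df) would work as well if pandas is available.
df = {"Q": Q, "X": X, "Y": Y, "P": P}
```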

Example 1: Adjustment model Y ~ X + Q

Suppose we care about the effect of X on Y. Given the causal DAG, we must adjust for the confounder Q:

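```python
import pymc as pm

# df: a table with columns Q, X, Y, P (not defined in this issue)
with pm.Model() as x_y_model:
    _Q = pm.Data("_Q", df["Q"])
    _X = pm.Data("_X", df["X"])
    _Y = pm.Data("_Y", df["Y"])

    beta_q = pm.Normal("beta_q")
    beta_x = pm.Normal("beta_x")
    sigma_y = pm.HalfNormal("sigma")

    # Adjustment set {Q}: regress Y on X and the confounder Q
    Y = pm.Normal("Y", mu=beta_x * _X + beta_q * _Q, sigma=sigma_y, observed=_Y)

# pymc_dpg_to_causal_dag is the helper proposed in this issue
causal_src = pymc_dpg_to_causal_dag(pm.model_to_graphviz(x_y_model).source)
```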

Result (causal-style): _X → _Y and _Q → _Y — exactly what we expect for the adjustment set. 😕

[Image: causal-style DAG produced for Example 1]

Note

If we want to capture other effects, such as the effect of Q on Y or of Q on P, we can fit additional regressions with other adjustment sets, and this function will produce a different causal-style DAG for each (essentially, each DPG maps to its own causal-style DAG).

Example 2: Fully specified “luxury” model

If we know the true causal DAG and build a fully specified PyMC model that mirrors it, the helper reproduces that DAG:

```python
import pymc as pm

# df: a table with columns Q, X, Y, P (not defined in this issue)
with pm.Model() as full_luxury_model:
    _Q = pm.Data("_Q", df["Q"])
    _X = pm.Data("_X", df["X"])
    _Y = pm.Data("_Y", df["Y"])
    _P = pm.Data("_P", df["P"])

    # slopes
    qx = pm.Normal("qx")      # X ~ Q
    xy = pm.Normal("xy")      # Y ~ X
    qy = pm.Normal("qy")      # Y ~ Q
    xp = pm.Normal("xp")      # P ~ X
    yp = pm.Normal("yp")      # P ~ Y

    # scales
    sigma_x = pm.HalfNormal("sigma_x")
    sigma_y = pm.HalfNormal("sigma_y")
    sigma_p = pm.HalfNormal("sigma_p")

    # One likelihood per node of the causal DAG; Q has no parents
    Q = pm.Normal("Q", observed=_Q)
    X = pm.Normal("X", mu=qx * _Q, sigma=sigma_x, observed=_X)
    Y = pm.Normal("Y", mu=xy * _X + qy * _Q, sigma=sigma_y, observed=_Y)
    P = pm.Normal("P", mu=xp * _X + yp * _Y, sigma=sigma_p, observed=_P)

# pymc_dpg_to_causal_dag is the helper proposed in this issue
causal_src = pymc_dpg_to_causal_dag(pm.model_to_graphviz(full_luxury_model).source)
```

Output: The original causal diagram 🔥

[Image: causal-style DAG recovered from full_luxury_model]

Proposed API

```python
def pymc_dpg_to_causal_dag(
    model_or_dot,                    # pm.Model | graphviz.Digraph | DOT string
    *,
    first_hit_only=True,             # connect to closest downstream Data
    node_style='style="filled"',     # Data node style
    unobserved_vars=None,            # e.g. ["intercept", "sigma"]; dashed ellipses
) -> "graphviz.Source":              # render with .render() or display in notebooks
    ...
```
  • First-hit rule: From each source node, walk forward; when you hit any pm.Data node, draw an edge and stop (don’t traverse beyond that data node). This avoids dense transitive edges and respects mediators (e.g., _Q → _Y, not _Q → _P via _Y).
  • Unobserved variables: If provided, render the listed node IDs (that exist in the graph) as dashed ellipses and connect them to their closest downstream data nodes.
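The following is a standard-library-only sketch of the proposed behavior. It makes several simplifications. It only accepts a DOT string, not a `pm.Model` or `graphviz.Digraph`. It returns a DOT string plus an edge list instead of a `graphviz.Source`. The name `dpg_to_causal_dag_sketch` and its helpers are hypothetical. It also assumes PyMC writes one node or edge statement per line and draws `pm.Data` containers as boxes whose style includes `rounded`. Those details may differ across PyMC versions.

```python
import re
from collections import deque

# Assumption: PyMC's DOT source has one statement per line, such as
#   _Q [label="_Q\n~\nData" shape=box style="rounded, filled"]   and   _Q -> Y
NODE_RE = re.compile(r'^\s*("[^"]+"|[\w.]+)\s*\[(.*)\]\s*$')
EDGE_RE = re.compile(r'^\s*("[^"]+"|[\w.]+)\s*->\s*("[^"]+"|[\w.]+)')


def parse_pymc_dot(dot):
    """Return node names (in order), the set of pm.Data nodes, and child lists."""
    nodes, data_nodes, children = [], set(), {}
    for line in dot.splitlines():
        if edge := EDGE_RE.match(line):
            src, dst = (g.strip('"') for g in edge.groups())
            children.setdefault(src, []).append(dst)
        elif (node := NODE_RE.match(line)) and node.group(1) not in ("graph", "node", "edge"):
            name, attrs = node.group(1).strip('"'), node.group(2)
            nodes.append(name)
            # Assumption: Data containers are rounded boxes (Deterministics are plain boxes).
            if "shape=box" in attrs and "rounded" in attrs:
                data_nodes.add(name)
    return nodes, data_nodes, children


def reachable_data(start, children, data_nodes, first_hit_only=True):
    """Data nodes reachable from `start`; with first_hit_only, never walk past one."""
    hits, seen = set(), {start}
    queue = deque(children.get(start, []))
    while queue:
        node = queue.popleft()
        if node in seen:
            continue
        seen.add(node)
        if node in data_nodes:
            hits.add(node)
            if first_hit_only:
                continue  # first-hit rule: stop this path at the Data node
        queue.extend(children.get(node, []))
    return hits


def dpg_to_causal_dag_sketch(dot, *, first_hit_only=True,
                             node_style='style="filled"', unobserved_vars=None):
    """Hypothetical helper: PyMC DOT string in, (causal-style DOT string, edge list) out."""
    nodes, data_nodes, children = parse_pymc_dot(dot)
    data = [n for n in nodes if n in data_nodes]
    # Keep only requested unobserved variables that exist in the graph.
    latent = [v for v in (unobserved_vars or []) if v in nodes and v not in data_nodes]

    lines = ["digraph {"]
    lines += [f'  "{n}" [shape=box {node_style}]' for n in data]
    lines += [f'  "{v}" [shape=ellipse style=dashed]' for v in latent]
    edges = []
    for src in data + latent:
        for dst in sorted(reachable_data(src, children, data_nodes, first_hit_only)):
            edges.append((src, dst))
            lines.append(f'  "{src}" -> "{dst}"')
    lines.append("}")
    return "\n".join(lines), edges


# Usage: a hand-written DOT string shaped like pm.model_to_graphviz(x_y_model).source
# for Example 1 (the exact attributes PyMC emits are an assumption).
example_1_dot = r"""digraph {
    beta_q [label="beta_q\n~\nNormal" shape=ellipse]
    beta_x [label="beta_x\n~\nNormal" shape=ellipse]
    sigma [label="sigma\n~\nHalfNormal" shape=ellipse]
    _Q [label="_Q\n~\nData" shape=box style="rounded, filled"]
    _X [label="_X\n~\nData" shape=box style="rounded, filled"]
    _Y [label="_Y\n~\nData" shape=box style="rounded, filled"]
    Y [label="Y\n~\nNormal" shape=ellipse style=filled]
    beta_q -> Y
    beta_x -> Y
    sigma -> Y
    _Q -> Y
    _X -> Y
    Y -> _Y
}"""

causal_dot, edges = dpg_to_causal_dag_sketch(example_1_dot, unobserved_vars=["sigma", "intercept"])
print(edges)  # [('_Q', '_Y'), ('_X', '_Y'), ('sigma', '_Y')]
```

`"intercept"` is silently skipped because it is not in the graph. If the `graphviz` package is installed, `graphviz.Source(causal_dot)` should render the result. The breadth-first walk keeps a `seen` set, so shared downstream nodes are visited at most once per source.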

Note

We could add this to pymc-extras first, see how the community uses it, and then decide whether it belongs in the main PyMC repo.

Limitations & guidance

  • This helper does not discover causal structure or validate identification.
  • Multiple PyMC model specifications that implement valid adjustment sets can render to the same or to different causal-style DAGs; that’s intended. (The mapping is defined per DPG: each model graph yields exactly one causal-style DAG.)
  • Best used when the causal DAG is known and you want to check that the PyMC model mirrors it, or to communicate the causal story more cleanly than the full probabilistic graph does.

If there’s interest, I’m happy to open a PR adding this as a documented recipe (with tests) or a small utility in an examples module. Check my draft in Google Colab.

cc: @drbenvincent @jessegrabowski @ricardoV94 @cluhmann @twiecki
