Skip to content

Commit 31ad037

Browse files
aryan26royadam2392pre-commit-ci[bot]
authored
[ENH] Add the ability to check the validity of a PAG (#100)
* Added a _proper_pag function * Added some test for legal_pag * [pre-commit.ci] auto fixes from pre-commit.com hooks * Update examples/intro/checking_validity_of_a_pag.py --------- Signed-off-by: Aryan Roy <[email protected]> Co-authored-by: Adam Li <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 71a18e7 commit 31ad037

File tree

10 files changed

+480
-21
lines changed

10 files changed

+480
-21
lines changed

.circleci/config.yml

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ commands:
77
- run:
88
name: Check-skip
99
command: |
10-
if [ ! -d "sktree" ]; then
10+
if [ ! -d "pywhy_graphs" ]; then
1111
echo "Build was not run due to skip, exiting job ${CIRCLE_JOB} for PR ${CIRCLE_PULL_REQUEST}."
1212
circleci-agent step halt;
1313
fi
@@ -54,7 +54,7 @@ commands:
5454
echo ${CI_PULL_REQUEST//*pull\//} | tee merge.txt
5555
if [[ $(cat merge.txt) != "" ]]; then
5656
echo "Merging $(cat merge.txt)";
57-
git remote add upstream https://github.com/neurodata/scikit-tree.git;
57+
git remote add upstream https://github.com/py-why/pywhy-graphs.git;
5858
git pull --ff-only upstream "refs/pull/$(cat merge.txt)/merge";
5959
git fetch upstream main;
6060
fi
@@ -64,7 +64,7 @@ jobs:
6464
docker:
6565
# CircleCI maintains a library of pre-built images
6666
# documented at https://circleci.com/doc/2.0/circleci-images/
67-
- image: cimg/python:3.9
67+
- image: cimg/python:3.11
6868
steps:
6969
- checkout
7070
- check-skip
@@ -96,18 +96,25 @@ jobs:
9696
name: Setup torch for pgmpy
9797
command: |
9898
sudo apt-get install nvidia-cuda-toolkit nvidia-cuda-toolkit-gcc
99+
- run:
100+
name: Install dodiscover
101+
command: |
102+
git clone https://github.com/py-why/dodiscover.git
103+
cd dodiscover
104+
python -m pip install .
99105
- run:
100106
name: Check installation
101107
command: |
102108
python -c "import pywhy_graphs;"
103109
python -c "import numpy; numpy.show_config()"
110+
python -c "import dodiscover;"
104111
LIBGL_DEBUG=verbose python -c "import matplotlib.pyplot as plt; plt.figure()"
105112
106113
# dowhy currently requires an older version of numpy
107-
- run:
108-
name: Temporary Hack for numpy
109-
command: |
110-
python -m pip install numpy==1.22.0
114+
# - run:
115+
# name: Temporary Hack for numpy
116+
# command: |
117+
# python -m pip install numpy==1.22.0
111118

112119
- run:
113120
name: Build documentation
Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,29 @@
1+
name: CircleCI artifacts redirector
12
on: [status]
3+
4+
# Restrict the permissions granted to the use of secrets.GITHUB_TOKEN in this
5+
# github actions workflow:
6+
# https://docs.github.com/en/actions/security-guides/automatic-token-authentication
7+
permissions: read-all
8+
29
jobs:
310
circleci_artifacts_redirector_job:
4-
if: "${{ startsWith(github.event.context, 'ci/circleci: build_doc') }}"
511
runs-on: ubuntu-20.04
12+
if: "github.repository == 'py-why/pywhy-graphs' && github.event.context == 'ci/circleci: build_doc'"
13+
permissions:
14+
statuses: write
615
name: Run CircleCI artifacts redirector
716
steps:
817
- name: GitHub Action step
918
uses: larsoner/circleci-artifacts-redirector-action@master
1019
with:
1120
repo-token: ${{ secrets.GITHUB_TOKEN }}
21+
api-token: ${{ secrets.CIRCLECI_TOKEN }}
1222
artifact-path: 0/dev/index.html
1323
circleci-jobs: build_doc
1424
job-title: Check the rendered docs here!
25+
26+
- name: Check the URL
27+
if: github.event.status != 'pending'
28+
run: |
29+
curl --fail ${{ steps.step1.outputs.url }} | grep $GITHUB_SHA

.github/workflows/main.yml

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -117,11 +117,26 @@ jobs:
117117
with:
118118
python-version: ${{ matrix.python-version }}
119119
architecture: "x64"
120+
121+
- name: Setup torch for pgmpy
122+
if: "matrix.os == 'ubuntu'"
123+
shell: bash
124+
run: |
125+
sudo apt-get update
126+
sudo apt-get install nvidia-cuda-toolkit nvidia-cuda-toolkit-gcc
127+
120128
- name: Install packages via pip
121129
run: |
122-
pip install --upgrade pip
123-
pip install numpy scipy networkx statsmodels
124-
pip install .[test]
130+
python -m pip install --upgrade pip
131+
python -m pip install numpy scipy networkx statsmodels
132+
python -m pip install .[test]
133+
134+
- name: Install DoDiscover (main)
135+
run: |
136+
git clone https://github.com/py-why/dodiscover.git
137+
cd dodiscover
138+
python -m pip install .
139+
125140
- name: Install Networkx (main)
126141
if: "matrix.networkx == 'main'"
127142
run: |
@@ -131,16 +146,10 @@ jobs:
131146
pip install .[default]
132147
# pip install --progress-bar off git+https://github.com/networkx/networkx
133148
134-
- name: Setup torch for pgmpy
135-
if: "matrix.os == 'ubuntu'"
136-
shell: bash
137-
run: |
138-
sudo apt-get update
139-
sudo apt-get install nvidia-cuda-toolkit nvidia-cuda-toolkit-gcc
140-
141149
- name: Run pytest # headless via Xvfb on linux
142150
run: |
143151
pytest --cov pywhy_graphs ./pywhy_graphs
152+
144153
- name: Upload coverage stats to codecov
145154
if: ${{ matrix.os == 'ubuntu' && matrix.python-version == '3.11' && matrix.networkx == 'stable' }}
146155
uses: codecov/codecov-action@v4

doc/reference/algorithms/index.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ Core Algorithms
2525
possible_descendants
2626
discriminating_path
2727
is_definite_noncollider
28+
valid_pag
29+
mag_to_pag
30+
pag_to_mag
31+
check_pag_definition
2832

2933
.. currentmodule:: pywhy_graphs.networkx
3034

doc/whats_new/v0.2.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ Changelog
3232
- |Feature| Implement a suite of functions for finding and checking semi-directed paths on a mixed-edge graph, by `Adam Li`_ (:pr:`101`)
3333
- |Feature| Implement functions for converting between a DAG and PDAG and CPDAG for generating consistent extensions of a CPDAG for example. These functions are :func:`pywhy_graphs.algorithms.pdag_to_cpdag`, :func:`pywhy_graphs.algorithms.pdag_to_dag` and :func:`pywhy_graphs.algorithms.dag_to_cpdag`, by `Adam Li`_ (:pr:`102`)
3434
- |API| Remove poetry based setup, by `Adam Li`_ (:pr:`110`)
35+
- |Feature| Implement and test function to validate PAG, by `Aryan Roy`_ (:pr:`100`)
3536

3637
Code and Documentation Contributors
3738
-----------------------------------
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
"""
2+
===========================
3+
On PAGs and their validity
4+
===========================
5+
6+
A PAG or a Partial Ancestral Graph is a type of mixed edge
7+
graph that can represent, in a single graph, the causal relationship
8+
between several nodes as defined by an equivalence class of MAGs.
9+
10+
PAGs account for possible unobserved confounding and selection bias
11+
in the underlying equivalence class of SCMs.
12+
13+
Another way to understand this is that PAGs encode conditional independence
14+
constraints stemming from Causal Graphs. Since these constraints do not lead to a
15+
unique graph, a PAG, in essence, represents a class of graphs that encode
16+
the same conditional independence constraints.
17+
18+
PAGs model this relationship by displaying all common edge marks (tail and arrowhead) shared
19+
by all members in the equivalence class and displaying circle endpoints for those marks
20+
that are not common. That is, a circular endpoint (``*-o``) can represent both a directed
21+
(``*->``) and tail (``*—``) endpoint in causal graphs within the equivalence class.
22+
23+
More details on PAGs can be found at :footcite:`Zhang2008`.
24+
25+
"""
26+
27+
import pywhy_graphs
28+
from pywhy_graphs.viz import draw
29+
from pywhy_graphs import PAG
30+
31+
try:
32+
from dodiscover import FCI, make_context
33+
from dodiscover.ci import Oracle
34+
from dodiscover.constraint.utils import dummy_sample
35+
except ImportError as e:
36+
raise ImportError("The 'dodiscover' package is required to convert a MAG to a PAG.")
37+
38+
39+
# %%
40+
# PAGs in pywhy-graphs
41+
# ---------------------------
42+
# Constructing a PAG in pywhy-graphs is an easy task since
43+
# the library provides a separate class for this purpose.
44+
# True to the definition of PAGs, the class can contain
45+
# directed edges, bidirected edges, undirected edges and
46+
# cicle edges. To illustrate this, we construct an example PAG
47+
# as described in :footcite:`Zhang2008`, figure 4:
48+
49+
pag = PAG()
50+
pag.add_edge("I", "S", pag.directed_edge_name)
51+
pag.add_edge("G", "S", pag.directed_edge_name)
52+
pag.add_edge("G", "L", pag.directed_edge_name)
53+
pag.add_edge("S", "L", pag.directed_edge_name)
54+
pag.add_edge("PSH", "S", pag.directed_edge_name)
55+
pag.add_edge("S", "PSH", pag.circle_edge_name)
56+
pag.add_edge("S", "G", pag.circle_edge_name)
57+
pag.add_edge("S", "I", pag.circle_edge_name)
58+
59+
60+
# Finally, the graph looks like this:
61+
dot_graph = draw(pag)
62+
dot_graph.render(outfile="valid_pag.png", view=True)
63+
64+
65+
# %%
66+
# Validity of a PAG
67+
# ---------------------------
68+
# For a PAG to be valid, it must represent a valid
69+
# equivalent class of MAGs. This can be verified by
70+
# turning the PAG into an MAG and then checking the
71+
# validity of the MAG.
72+
# Theorem 2 in :footcite:`Zhang2008` provides a method for checking the validity of a PAG.
73+
# To check if the constructed PAG is a valid one in pywhy-graphs, we can simply do:
74+
75+
76+
# returns True
77+
print(pywhy_graphs.valid_pag(pag))
78+
79+
# %%
80+
# If we want to test whether this algorithm is working correctly or not, we can change
81+
# a single mark in the graph such that the PAG. By removing a circle edge, we are removing
82+
# the representation of multiple marks as encoded by the different MAGs this PAG represents.
83+
# In this specific case, by removing the circle endpoint ``S *-o I``, we are saying that ``S``
84+
# directly causes ``I``. However, there is no way of determining this using the FCI logical rules.
85+
# One would not be able to determine that the adjacency is due to a direct
86+
# causal relationship (directed edge), confounded relationship (bidirected edge), or an inducing path
87+
# relationship. As such, the resulting graph is no longer a valid PAG.
88+
89+
pag.remove_edge("S", "I", pag.circle_edge_name)
90+
91+
# returns False
92+
print(pywhy_graphs.valid_pag(pag))
93+
94+
# %%
95+
# References
96+
# ----------
97+
# .. footbibliography::

examples/intro/intro_causal_graphs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,10 @@ def clone(self):
136136
# Using the graph, we can explore d-separation statements, which by the Markov
137137
# condition, imply conditional independences.
138138
# For example, 'z' is d-separated from 'x' because of the collider at 'y'
139-
print(f"'z' is d-separated from 'x': {nx.is_d_separator(G, {'z'}, {'x'}, set())}")
139+
print(f"'z' is d-separated from 'x': {nx.d_separated(G, {'z'}, {'x'}, set())}")
140140

141141
# Conditioning on the collider, opens up the path
142-
print(f"'z' is d-separated from 'x' given 'y': {nx.is_d_separator(G, {'z'}, {'x'}, {'y'})}")
142+
print(f"'z' is d-separated from 'x' given 'y': {nx.d_separated(G, {'z'}, {'x'}, {'y'})}")
143143

144144
# %%
145145
# Acyclic Directed Mixed Graphs (ADMG)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ doc = [
5252
'graphviz',
5353
'pygraphviz',
5454
'pgmpy',
55+
'dowhy',
5556
]
5657
style = [
5758
"pre-commit",

0 commit comments

Comments
 (0)