Skip to content

Commit e209cbe

Browse files
Merge branch 'main' into pylint_refactoring
# Conflicts: # causal_testing/testing/causal_test_engine.py # tests/testing_tests/test_causal_test_outcome.py
2 parents 9066669 + 65765cb commit e209cbe

File tree

19 files changed

+322
-150
lines changed

19 files changed

+322
-150
lines changed

.github/workflows/publish-to-pypi.yaml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ name: Publish python PyPI
33
on:
44
push:
55
tags:
6-
- v0.0.0
6+
- v*
7+
78
jobs:
89
build-release:
910
name: Build and publish PyPI
@@ -21,6 +22,8 @@ jobs:
2122
pip3 install .
2223
pip3 install .[pypi]
2324
pip3 install build
25+
pip3 install setuptools --upgrade
26+
pip3 install setuptools_scm
2427
- name: Build Package
2528
run: |
2629
python -m build --no-isolation

.github/workflows/publish-to-test-pypi.yaml

Lines changed: 0 additions & 34 deletions
This file was deleted.

causal_testing/json_front/json_class.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -199,12 +199,12 @@ def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) ->
199199
treatment_var = causal_test_case.treatment_variable
200200
minimal_adjustment_set = minimal_adjustment_set - {treatment_var}
201201
estimation_model = estimator(
202-
(treatment_var.name,),
203-
causal_test_case.treatment_value,
204-
causal_test_case.control_value,
205-
minimal_adjustment_set,
206-
(causal_test_case.outcome_variable.name,),
207-
causal_test_engine.scenario_execution_data_df,
202+
treatment=treatment_var.name,
203+
treatment_value=causal_test_case.treatment_value,
204+
control_value=causal_test_case.control_value,
205+
adjustment_set=minimal_adjustment_set,
206+
outcome=causal_test_case.outcome_variable.name,
207+
df=causal_test_engine.scenario_execution_data_df,
208208
effect_modifiers=causal_test_case.effect_modifier_configuration,
209209
)
210210

causal_testing/specification/causal_dag.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,36 @@ def __init__(self, dot_path: str = None, **attr):
138138
if not self.is_acyclic():
139139
raise nx.HasACycle("Invalid Causal DAG: contains a cycle.")
140140

141+
def check_iv_assumptions(self, treatment, outcome, instrument) -> bool:
142+
"""
143+
Checks the three instrumental variable assumptions, raising a
144+
ValueError if any are violated.
145+
146+
:return Boolean True if the three IV assumptions hold.
147+
"""
148+
# (i) Instrument is associated with treatment
149+
if nx.d_separated(self.graph, {instrument}, {treatment}, set()):
150+
raise ValueError(f"Instrument {instrument} is not associated with treatment {treatment} in the DAG")
151+
152+
# (ii) Instrument does not affect outcome except through its potential effect on treatment
153+
if not all([treatment in path for path in nx.all_simple_paths(self.graph, source=instrument, target=outcome)]):
154+
raise ValueError(
155+
f"Instrument {instrument} affects the outcome {outcome} other than through the treatment {treatment}"
156+
)
157+
158+
# (iii) Instrument and outcome do not share causes
159+
if any(
160+
[
161+
cause
162+
for cause in self.graph.nodes
163+
if list(nx.all_simple_paths(self.graph, source=cause, target=instrument))
164+
and list(nx.all_simple_paths(self.graph, source=cause, target=outcome))
165+
]
166+
):
167+
raise ValueError(f"Instrument {instrument} and outcome {outcome} share common causes")
168+
169+
return True
170+
141171
def add_edge(self, u_of_edge: Node, v_of_edge: Node, **attr):
142172
"""Add an edge to the causal DAG.
143173

causal_testing/testing/causal_test_engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,11 @@ def execute_test_suite(self, test_suite: CausalTestSuite) -> list[CausalTestResu
8686

8787
for test in tests:
8888
estimator = estimator_class(
89-
(test.treatment_variable.name,),
89+
test.treatment_variable.name,
9090
test.treatment_value,
9191
test.control_value,
9292
minimal_adjustment_set,
93-
(test.outcome_variable.name,),
93+
test.outcome_variable.name,
9494
)
9595
if estimator.df is None:
9696
estimator.df = self.scenario_execution_data_df

causal_testing/testing/causal_test_result.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,22 +68,22 @@ def to_dict(self):
6868
"adjustment_set": self.adjustment_set,
6969
"test_value": self.test_value,
7070
}
71-
if self.confidence_intervals:
71+
if self.confidence_intervals and all(self.confidence_intervals):
7272
base_dict["ci_low"] = min(self.confidence_intervals)
7373
base_dict["ci_high"] = max(self.confidence_intervals)
7474
return base_dict
7575

7676
def ci_low(self):
7777
"""Return the lower bracket of the confidence intervals."""
78-
if not self.confidence_intervals:
79-
return None
80-
return min(self.confidence_intervals)
78+
if self.confidence_intervals and all(self.confidence_intervals):
79+
return min(self.confidence_intervals)
80+
return None
8181

8282
def ci_high(self):
8383
"""Return the higher bracket of the confidence intervals."""
84-
if not self.confidence_intervals:
85-
return None
86-
return max(self.confidence_intervals)
84+
if self.confidence_intervals and all(self.confidence_intervals):
85+
return max(self.confidence_intervals)
86+
return None
8787

8888
def summary(self):
8989
"""Summarise the causal test result as an intuitive sentence."""

0 commit comments

Comments
 (0)