Skip to content

Commit 9abdcbc

Browse files
authored
Merge pull request #257 from CITCOM-project/conditions
We now support placing conditions on the data again.
2 parents e7a3e90 + 41bd188 commit 9abdcbc

File tree

4 files changed

+66
-22
lines changed

4 files changed

+66
-22
lines changed

causal_testing/json_front/json_class.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -301,9 +301,6 @@ def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> Estima
301301
"""Create the necessary inputs for a single test case
302302
:param causal_test_case: The concrete test case to be executed
303303
:param test: Single JSON test definition stored in a mapping (dict)
304-
:param conditions: A list of conditions which should be applied to the
305-
data. Conditions should be in the query format detailed at
306-
https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.query.html
307304
:returns:
308305
- estimation_model - Estimator instance for the test being run
309306
"""
@@ -323,11 +320,13 @@ def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> Estima
323320
minimal_adjustment_set = minimal_adjustment_set - {causal_test_case.treatment_variable}
324321
estimator_kwargs["adjustment_set"] = minimal_adjustment_set
325322

323+
estimator_kwargs["query"] = test["query"] if "query" in test else ""
326324
estimator_kwargs["treatment"] = causal_test_case.treatment_variable.name
327325
estimator_kwargs["treatment_value"] = causal_test_case.treatment_value
328326
estimator_kwargs["control_value"] = causal_test_case.control_value
329327
estimator_kwargs["outcome"] = causal_test_case.outcome_variable.name
330328
estimator_kwargs["effect_modifiers"] = causal_test_case.effect_modifier_configuration
329+
estimator_kwargs["df"] = self.data_collector.collect_data()
331330
estimator_kwargs["alpha"] = test["alpha"] if "alpha" in test else 0.05
332331

333332
estimation_model = test["estimator"](**estimator_kwargs)

causal_testing/testing/causal_test_result.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,10 @@ def push(s, inc=" "):
5959
f"Treatment value: {self.estimator.treatment_value}\n"
6060
f"Outcome: {self.estimator.outcome}\n"
6161
f"Adjustment set: {self.adjustment_set}\n"
62-
f"Formula: {self.estimator.formula}\n"
63-
f"{self.test_value.type}: {result_str}\n"
6462
)
63+
if hasattr(self.estimator, "formula"):
64+
base_str += f"Formula: {self.estimator.formula}\n"
65+
base_str += f"{self.test_value.type}: {result_str}\n"
6566
confidence_str = ""
6667
if self.confidence_intervals:
6768
ci_str = " " + str(self.confidence_intervals)

causal_testing/testing/estimators.py

Lines changed: 51 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import statsmodels.api as sm
1111
import statsmodels.formula.api as smf
1212
from econml.dml import CausalForestDML
13-
from patsy import dmatrix
13+
from patsy import dmatrix # pylint: disable = no-name-in-module
1414

1515
from sklearn.ensemble import GradientBoostingRegressor
1616
from statsmodels.regression.linear_model import RegressionResultsWrapper
@@ -50,21 +50,25 @@ def __init__(
5050
df: pd.DataFrame = None,
5151
effect_modifiers: dict[str:Any] = None,
5252
alpha: float = 0.05,
53+
query: str = "",
5354
):
5455
self.treatment = treatment
5556
self.treatment_value = treatment_value
5657
self.control_value = control_value
5758
self.adjustment_set = adjustment_set
5859
self.outcome = outcome
59-
self.df = df
6060
self.alpha = alpha
61+
self.df = df.query(query) if query else df
62+
6163
if effect_modifiers is None:
6264
self.effect_modifiers = {}
6365
elif isinstance(effect_modifiers, dict):
6466
self.effect_modifiers = effect_modifiers
6567
else:
6668
raise ValueError(f"Unsupported type for effect_modifiers {effect_modifiers}. Expected iterable")
6769
self.modelling_assumptions = []
70+
if query:
71+
self.modelling_assumptions.append(query)
6872
self.add_modelling_assumptions()
6973
logger.debug("Effect Modifiers: %s", self.effect_modifiers)
7074

@@ -100,8 +104,18 @@ def __init__(
100104
df: pd.DataFrame = None,
101105
effect_modifiers: dict[str:Any] = None,
102106
formula: str = None,
107+
query: str = "",
103108
):
104-
super().__init__(treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers)
109+
super().__init__(
110+
treatment=treatment,
111+
treatment_value=treatment_value,
112+
control_value=control_value,
113+
adjustment_set=adjustment_set,
114+
outcome=outcome,
115+
df=df,
116+
effect_modifiers=effect_modifiers,
117+
query=query,
118+
)
105119

106120
self.model = None
107121

@@ -116,13 +130,13 @@ def add_modelling_assumptions(self):
116130
Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that
117131
must hold if the resulting causal inference is to be considered valid.
118132
"""
119-
self.modelling_assumptions += (
133+
self.modelling_assumptions.append(
120134
"The variables in the data must fit a shape which can be expressed as a linear"
121135
"combination of parameters and functions of variables. Note that these functions"
122136
"do not need to be linear."
123137
)
124-
self.modelling_assumptions += "The outcome must be binary."
125-
self.modelling_assumptions += "Independently and identically distributed errors."
138+
self.modelling_assumptions.append("The outcome must be binary.")
139+
self.modelling_assumptions.append("Independently and identically distributed errors.")
126140

127141
def _run_logistic_regression(self, data) -> RegressionResultsWrapper:
128142
"""Run logistic regression of the treatment and adjustment set against the outcome and return the model.
@@ -291,9 +305,18 @@ def __init__(
291305
effect_modifiers: dict[Variable:Any] = None,
292306
formula: str = None,
293307
alpha: float = 0.05,
308+
query: str = "",
294309
):
295310
super().__init__(
296-
treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers, alpha=alpha
311+
treatment,
312+
treatment_value,
313+
control_value,
314+
adjustment_set,
315+
outcome,
316+
df,
317+
effect_modifiers,
318+
alpha=alpha,
319+
query=query,
297320
)
298321

299322
self.model = None
@@ -314,7 +337,7 @@ def add_modelling_assumptions(self):
314337
Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that
315338
must hold if the resulting causal inference is to be considered valid.
316339
"""
317-
self.modelling_assumptions += (
340+
self.modelling_assumptions.append(
318341
"The variables in the data must fit a shape which can be expressed as a linear"
319342
"combination of parameters and functions of variables. Note that these functions"
320343
"do not need to be linear."
@@ -509,8 +532,20 @@ def __init__(
509532
df: pd.DataFrame = None,
510533
intercept: int = 1,
511534
effect_modifiers: dict = None, # Not used (yet?). Needed for compatibility
535+
alpha: float = 0.05,
536+
query: str = "",
512537
):
513-
super().__init__(treatment, treatment_value, control_value, adjustment_set, outcome, df, None)
538+
super().__init__(
539+
treatment=treatment,
540+
treatment_value=treatment_value,
541+
control_value=control_value,
542+
adjustment_set=adjustment_set,
543+
outcome=outcome,
544+
df=df,
545+
effect_modifiers=None,
546+
alpha=alpha,
547+
query=query,
548+
)
514549
self.intercept = intercept
515550
self.model = None
516551
self.instrument = instrument
@@ -520,13 +555,17 @@ def add_modelling_assumptions(self):
520555
Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that
521556
must hold if the resulting causal inference is to be considered valid.
522557
"""
523-
self.modelling_assumptions += """The instrument and the treatment, and the treatment and the outcome must be
558+
self.modelling_assumptions.append(
559+
"""The instrument and the treatment, and the treatment and the outcome must be
524560
related linearly in the form Y = aX + b."""
525-
self.modelling_assumptions += """The three IV conditions must hold
561+
)
562+
self.modelling_assumptions.append(
563+
"""The three IV conditions must hold
526564
(i) Instrument is associated with treatment
527565
(ii) Instrument does not affect outcome except through its potential effect on treatment
528566
(iii) Instrument and outcome do not share causes
529567
"""
568+
)
530569

531570
def estimate_iv_coefficient(self, df):
532571
"""
@@ -569,7 +608,7 @@ def add_modelling_assumptions(self):
569608
570609
:return self: Update self.modelling_assumptions
571610
"""
572-
self.modelling_assumptions += "Non-parametric estimator: no restrictions imposed on the data."
611+
self.modelling_assumptions.append("Non-parametric estimator: no restrictions imposed on the data.")
573612

574613
def estimate_ate(self) -> float:
575614
"""Estimate the average treatment effect.

tests/testing_tests/test_estimators.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,16 +125,14 @@ def test_ate_adjustment(self):
125125
logistic_regression_estimator = LogisticRegressionEstimator(
126126
"length_in", 65, 55, {"large_gauge"}, "completed", df
127127
)
128-
ate, _ = logistic_regression_estimator.estimate_ate(adjustment_config = {"large_gauge": 0})
128+
ate, _ = logistic_regression_estimator.estimate_ate(adjustment_config={"large_gauge": 0})
129129
self.assertEqual(round(ate, 4), -0.3388)
130130

131131
def test_ate_invalid_adjustment(self):
132132
df = self.scarf_df.copy()
133133
logistic_regression_estimator = LogisticRegressionEstimator("length_in", 65, 55, {}, "completed", df)
134134
with self.assertRaises(ValueError):
135-
ate, _ = logistic_regression_estimator.estimate_ate(
136-
adjustment_config = {"large_gauge": 0}
137-
)
135+
ate, _ = logistic_regression_estimator.estimate_ate(adjustment_config={"large_gauge": 0})
138136

139137
def test_ate_effect_modifiers(self):
140138
df = self.scarf_df.copy()
@@ -216,6 +214,13 @@ def setUpClass(cls) -> None:
216214
cls.nhefs_df = load_nhefs_df()
217215
cls.chapter_11_df = load_chapter_11_df()
218216

217+
def test_query(self):
218+
df = self.nhefs_df
219+
linear_regression_estimator = LinearRegressionEstimator(
220+
"treatments", None, None, set(), "outcomes", df, query="sex==1"
221+
)
222+
self.assertTrue(linear_regression_estimator.df.sex.all())
223+
219224
def test_program_11_2(self):
220225
"""Test whether our linear regression implementation produces the same results as program 11.2 (p. 141)."""
221226
df = self.chapter_11_df
@@ -395,7 +400,7 @@ def test_program_15_no_interaction_ate_calculated(self):
395400
# for term_to_square in terms_to_square:
396401

397402
ate, [ci_low, ci_high] = linear_regression_estimator.estimate_ate_calculated(
398-
adjustment_config = {k: self.nhefs_df.mean()[k] for k in covariates}
403+
adjustment_config={k: self.nhefs_df.mean()[k] for k in covariates}
399404
)
400405
self.assertEqual(round(ate, 1), 3.5)
401406
self.assertEqual([round(ci_low, 1), round(ci_high, 1)], [1.9, 5])

0 commit comments

Comments
 (0)