Skip to content

Commit 868ea5d

Browse files
committed
We now support placing conditions on the data again.
1 parent b82241f commit 868ea5d

File tree

2 files changed

+39
-13
lines changed

2 files changed

+39
-13
lines changed

causal_testing/json_front/json_class.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -301,9 +301,6 @@ def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> Estima
301301
"""Create the necessary inputs for a single test case
302302
:param causal_test_case: The concrete test case to be executed
303303
:param test: Single JSON test definition stored in a mapping (dict)
304-
:param conditions: A list of conditions which should be applied to the
305-
data. Conditions should be in the query format detailed at
306-
https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.query.html
307304
:returns:
308305
- estimation_model - Estimator instance for the test being run
309306
"""
@@ -315,6 +312,7 @@ def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> Estima
315312
"formulas"
316313
)
317314
estimator_kwargs["formula"] = test["formula"]
315+
estimator_kwargs["query"] = test["query"] if "query" in test else ""
318316
estimator_kwargs["adjustment_set"] = None
319317
else:
320318
minimal_adjustment_set = self.causal_specification.causal_dag.identification(
@@ -328,6 +326,7 @@ def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> Estima
328326
estimator_kwargs["control_value"] = causal_test_case.control_value
329327
estimator_kwargs["outcome"] = causal_test_case.outcome_variable.name
330328
estimator_kwargs["effect_modifiers"] = causal_test_case.effect_modifier_configuration
329+
estimator_kwargs["df"] = self.data_collector.collect_data()
331330
estimator_kwargs["alpha"] = test["alpha"] if "alpha" in test else 0.05
332331

333332
estimation_model = test["estimator"](**estimator_kwargs)

causal_testing/testing/estimators.py

Lines changed: 37 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -50,21 +50,25 @@ def __init__(
5050
df: pd.DataFrame = None,
5151
effect_modifiers: dict[str:Any] = None,
5252
alpha: float = 0.05,
53+
query: str = "",
5354
):
5455
self.treatment = treatment
5556
self.treatment_value = treatment_value
5657
self.control_value = control_value
5758
self.adjustment_set = adjustment_set
5859
self.outcome = outcome
59-
self.df = df
6060
self.alpha = alpha
61+
self.df = df.query(query) if query else df
62+
6163
if effect_modifiers is None:
6264
self.effect_modifiers = {}
6365
elif isinstance(effect_modifiers, dict):
6466
self.effect_modifiers = effect_modifiers
6567
else:
6668
raise ValueError(f"Unsupported type for effect_modifiers {effect_modifiers}. Expected iterable")
6769
self.modelling_assumptions = []
70+
if query:
71+
self.modelling_assumptions.append(query)
6872
self.add_modelling_assumptions()
6973
logger.debug("Effect Modifiers: %s", self.effect_modifiers)
7074

@@ -100,8 +104,18 @@ def __init__(
100104
df: pd.DataFrame = None,
101105
effect_modifiers: dict[str:Any] = None,
102106
formula: str = None,
107+
query: str = "",
103108
):
104-
super().__init__(treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers)
109+
super().__init__(
110+
treatment=treatment,
111+
treatment_value=treatment_value,
112+
control_value=control_value,
113+
adjustment_set=adjustment_set,
114+
outcome=outcome,
115+
df=df,
116+
effect_modifiers=effect_modifiers,
117+
query=query,
118+
)
105119

106120
self.model = None
107121

@@ -116,13 +130,13 @@ def add_modelling_assumptions(self):
116130
Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that
117131
must hold if the resulting causal inference is to be considered valid.
118132
"""
119-
self.modelling_assumptions += (
133+
self.modelling_assumptions.append(
120134
"The variables in the data must fit a shape which can be expressed as a linear"
121135
"combination of parameters and functions of variables. Note that these functions"
122136
"do not need to be linear."
123137
)
124-
self.modelling_assumptions += "The outcome must be binary."
125-
self.modelling_assumptions += "Independently and identically distributed errors."
138+
self.modelling_assumptions.append("The outcome must be binary.")
139+
self.modelling_assumptions.append("Independently and identically distributed errors.")
126140

127141
def _run_logistic_regression(self, data) -> RegressionResultsWrapper:
128142
"""Run logistic regression of the treatment and adjustment set against the outcome and return the model.
@@ -291,9 +305,18 @@ def __init__(
291305
effect_modifiers: dict[Variable:Any] = None,
292306
formula: str = None,
293307
alpha: float = 0.05,
308+
query: str = "",
294309
):
295310
super().__init__(
296-
treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers, alpha=alpha
311+
treatment,
312+
treatment_value,
313+
control_value,
314+
adjustment_set,
315+
outcome,
316+
df,
317+
effect_modifiers,
318+
alpha=alpha,
319+
query=query,
297320
)
298321

299322
self.model = None
@@ -314,7 +337,7 @@ def add_modelling_assumptions(self):
314337
Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that
315338
must hold if the resulting causal inference is to be considered valid.
316339
"""
317-
self.modelling_assumptions += (
340+
self.modelling_assumptions.append(
318341
"The variables in the data must fit a shape which can be expressed as a linear"
319342
"combination of parameters and functions of variables. Note that these functions"
320343
"do not need to be linear."
@@ -468,13 +491,17 @@ def add_modelling_assumptions(self):
468491
Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that
469492
must hold if the resulting causal inference is to be considered valid.
470493
"""
471-
self.modelling_assumptions += """The instrument and the treatment, and the treatment and the outcome must be
494+
self.modelling_assumptions.append(
495+
"""The instrument and the treatment, and the treatment and the outcome must be
472496
related linearly in the form Y = aX + b."""
473-
self.modelling_assumptions += """The three IV conditions must hold
497+
)
498+
self.modelling_assumptions.append(
499+
"""The three IV conditions must hold
474500
(i) Instrument is associated with treatment
475501
(ii) Instrument does not affect outcome except through its potential effect on treatment
476502
(iii) Instrument and outcome do not share causes
477503
"""
504+
)
478505

479506
def estimate_iv_coefficient(self, df):
480507
"""
@@ -517,7 +544,7 @@ def add_modelling_assumptions(self):
517544
518545
:return self: Update self.modelling_assumptions
519546
"""
520-
self.modelling_assumptions += "Non-parametric estimator: no restrictions imposed on the data."
547+
self.modelling_assumptions.append("Non-parametric estimator: no restrictions imposed on the data.")
521548

522549
def estimate_ate(self) -> float:
523550
"""Estimate the average treatment effect.

0 commit comments

Comments
 (0)