Skip to content

Commit 7ad5827

Browse files
authored
Merge pull request #207 from CITCOM-project/alpha
Alpha
2 parents 3be14bc + 841e601 commit 7ad5827

File tree

7 files changed

+25
-22
lines changed

7 files changed

+25
-22
lines changed

causal_testing/json_front/json_class.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,7 @@ def _setup_test(
313313
"outcome": causal_test_case.outcome_variable.name,
314314
"df": causal_test_engine.scenario_execution_data_df,
315315
"effect_modifiers": causal_test_case.effect_modifier_configuration,
316+
"alpha": test["alpha"] if "alpha" in test else 0.05,
316317
}
317318
if "formula" in test:
318319
estimator_kwargs["formula"] = test["formula"]

causal_testing/specification/metamorphic_relation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ def to_json_stub(self, skip=True) -> dict:
181181
"mutations": [self.treatment_var],
182182
"expected_effect": {self.output_var: "NoEffect"},
183183
"formula": f"{self.output_var} ~ {' + '.join([self.treatment_var] + self.adjustment_vars)}",
184+
"alpha": 0.05,
184185
"skip": skip,
185186
}
186187

causal_testing/testing/causal_test_outcome.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,13 @@ def apply(self, res: CausalTestResult) -> bool:
5151
ci_low = res.ci_low() if isinstance(res.ci_low(), Iterable) else [res.ci_low()]
5252
ci_high = res.ci_high() if isinstance(res.ci_high(), Iterable) else [res.ci_high()]
5353
value = res.test_value.value if isinstance(res.ci_high(), Iterable) else [res.test_value.value]
54+
55+
if not all(ci_low < 0 < ci_high for ci_low, ci_high in zip(ci_low, ci_high)):
56+
print(
57+
"FAILING ON",
58+
[(ci_low, ci_high) for ci_low, ci_high in zip(ci_low, ci_high) if not ci_low < 0 < ci_high],
59+
)
60+
5461
return all(ci_low < 0 < ci_high for ci_low, ci_high in zip(ci_low, ci_high)) or all(
5562
abs(v) < self.atol for v in value
5663
)

causal_testing/testing/causal_test_result.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def push(s, inc=" "):
5656
f"Treatment value: {self.estimator.treatment_value}\n"
5757
f"Outcome: {self.estimator.outcome}\n"
5858
f"Adjustment set: {self.adjustment_set}\n"
59+
f"Formula: {self.estimator.formula}\n"
5960
f"{self.test_value.type}: {result_str}\n"
6061
)
6162
confidence_str = ""
@@ -64,6 +65,7 @@ def push(s, inc=" "):
6465
if "\n" in ci_str:
6566
ci_str = " " + push(pd.DataFrame(self.confidence_intervals).transpose().to_string(header=False))
6667
confidence_str += f"Confidence intervals:{ci_str}\n"
68+
confidence_str += f"Alpha:{self.estimator.alpha}\n"
6769
return base_str + confidence_str
6870

6971
def to_dict(self):

causal_testing/testing/estimators.py

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,15 @@ def __init__(
4949
outcome: str,
5050
df: pd.DataFrame = None,
5151
effect_modifiers: dict[str:Any] = None,
52+
alpha: float = 0.05,
5253
):
5354
self.treatment = treatment
5455
self.treatment_value = treatment_value
5556
self.control_value = control_value
5657
self.adjustment_set = adjustment_set
5758
self.outcome = outcome
5859
self.df = df
60+
self.alpha = alpha
5961
if effect_modifiers is None:
6062
self.effect_modifiers = {}
6163
elif isinstance(effect_modifiers, dict):
@@ -237,7 +239,7 @@ def estimate_ate(self, estimator_params: dict = None) -> float:
237239
return estimate, (None, None)
238240

239241
bootstraps = sorted(list(treatment_bootstraps - control_bootstraps))
240-
bound = int((bootstrap_size * 0.05) / 2)
242+
bound = int((bootstrap_size * self.alpha) / 2)
241243
ci_low = bootstraps[bound]
242244
ci_high = bootstraps[bootstrap_size - bound]
243245

@@ -271,7 +273,7 @@ def estimate_risk_ratio(self, estimator_params: dict = None) -> float:
271273
return estimate, (None, None)
272274

273275
bootstraps = sorted(list(treatment_bootstraps / control_bootstraps))
274-
bound = ceil((bootstrap_size * 0.05) / 2)
276+
bound = ceil((bootstrap_size * self.alpha) / 2)
275277
ci_low = bootstraps[bound]
276278
ci_high = bootstraps[bootstrap_size - bound]
277279

@@ -309,8 +311,11 @@ def __init__(
309311
df: pd.DataFrame = None,
310312
effect_modifiers: dict[Variable:Any] = None,
311313
formula: str = None,
314+
alpha: float = 0.05,
312315
):
313-
super().__init__(treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers)
316+
super().__init__(
317+
treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers, alpha=alpha
318+
)
314319

315320
self.model = None
316321
if effect_modifiers is None:
@@ -344,7 +349,6 @@ def estimate_unit_ate(self) -> float:
344349
"""
345350
model = self._run_linear_regression()
346351
newline = "\n"
347-
print(model.conf_int())
348352
treatment = [self.treatment]
349353
if str(self.df.dtypes[self.treatment]) == "object":
350354
design_info = dmatrix(self.formula.split("~")[1], self.df).design_info
@@ -380,7 +384,7 @@ def estimate_ate(self) -> tuple[float, list[float, float], float]:
380384
# Perform a t-test to compare the predicted outcome of the control and treated individual (ATE)
381385
t_test_results = model.t_test(individuals.loc["treated"] - individuals.loc["control"])
382386
ate = t_test_results.effect[0]
383-
confidence_intervals = list(t_test_results.conf_int().flatten())
387+
confidence_intervals = list(t_test_results.conf_int(alpha=self.alpha).flatten())
384388
return ate, confidence_intervals
385389

386390
def estimate_control_treatment(self, adjustment_config: dict = None) -> tuple[pd.Series, pd.Series]:
@@ -442,25 +446,11 @@ def _run_linear_regression(self) -> RegressionResultsWrapper:
442446
443447
:return: The model after fitting to data.
444448
"""
445-
# 1. Reduce dataframe to contain only the necessary columns
446-
reduced_df = self.df.copy()
447-
necessary_cols = [self.treatment] + list(self.adjustment_set) + [self.outcome]
448-
missing_rows = reduced_df[necessary_cols].isnull().any(axis=1)
449-
reduced_df = reduced_df[~missing_rows]
450-
reduced_df = reduced_df.sort_values([self.treatment])
451-
logger.debug(reduced_df[necessary_cols])
452-
453-
# 2. Add intercept
454-
reduced_df["Intercept"] = 1 # self.intercept
455-
456-
# 3. Estimate the unit difference in outcome caused by unit difference in treatment
457-
cols = [self.treatment]
458-
cols += [x for x in self.adjustment_set if x not in cols]
459449
model = smf.ols(formula=self.formula, data=self.df).fit()
460450
return model
461451

462452
def _get_confidence_intervals(self, model, treatment):
463-
confidence_intervals = model.conf_int(alpha=0.05, cols=None)
453+
confidence_intervals = model.conf_int(alpha=self.alpha, cols=None)
464454
ci_low, ci_high = (
465455
confidence_intervals[0].loc[treatment],
466456
confidence_intervals[1].loc[treatment],
@@ -527,7 +517,7 @@ def estimate_unit_ate(self, bootstrap_size=100):
527517
bootstraps = sorted(
528518
[self.estimate_coefficient(self.df.sample(len(self.df), replace=True)) for _ in range(bootstrap_size)]
529519
)
530-
bound = ceil((bootstrap_size * 0.05) / 2)
520+
bound = ceil((bootstrap_size * self.alpha) / 2)
531521
ci_low = bootstraps[bound]
532522
ci_high = bootstraps[bootstrap_size - bound]
533523

@@ -618,7 +608,7 @@ def estimate_cates(self) -> pd.DataFrame:
618608
# Obtain CATES and confidence intervals
619609
conditional_ates = model.effect(effect_modifier_df, T0=self.control_value, T1=self.treatment_value).flatten()
620610
[ci_low, ci_high] = model.effect_interval(
621-
effect_modifier_df, T0=self.control_value, T1=self.treatment_value, alpha=0.05
611+
effect_modifier_df, T0=self.control_value, T1=self.treatment_value, alpha=self.alpha
622612
)
623613

624614
# Merge results into a dataframe (CATE, confidence intervals, and effect modifier values)

tests/specification_tests/test_metamorphic_relations.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def test_should_not_cause_json_stub(self):
120120
"mutations": ["X1"],
121121
"name": "X1 _||_ Z",
122122
"formula": "Z ~ X1",
123+
"alpha": 0.05,
123124
"skip": True,
124125
},
125126
)

tests/testing_tests/test_causal_test_outcome.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def test_empty_adjustment_set(self):
6060
"Treatment value: 1\n"
6161
"Outcome: A\n"
6262
"Adjustment set: set()\n"
63+
"Formula: A ~ A\n"
6364
"ate: 0\n"
6465
),
6566
)

0 commit comments

Comments
 (0)