Skip to content

Commit 7e443b7

Browse files
committed
Case study code now works
1 parent 132a0fe commit 7e443b7

File tree

6 files changed

+55
-24
lines changed

6 files changed

+55
-24
lines changed

causal_testing/data_collection/data_collector.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def filter_valid_data(self, data: pd.DataFrame, check_pos: bool = True) -> pd.Da
6161
self.scenario.variables[var].z3
6262
== self.scenario.variables[var].z3_val(self.scenario.variables[var].z3, row[var])
6363
for var in self.scenario.variables
64-
if var in row
64+
if var in row and not pd.isnull(row[var])
6565
]
6666
for c in model:
6767
solver.assert_and_track(c, f"model: {c}")
@@ -147,7 +147,8 @@ def collect_data(self, **kwargs) -> pd.DataFrame:
147147

148148
execution_data_df = self.data
149149
for meta in self.scenario.metas():
150-
meta.populate(execution_data_df)
150+
if meta.name not in self.data:
151+
meta.populate(execution_data_df)
151152
scenario_execution_data_df = self.filter_valid_data(execution_data_df)
152153
for var_name, var in self.scenario.variables.items():
153154
if issubclass(var.datatype, Enum):

causal_testing/json_front/json_class.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -170,11 +170,14 @@ def _populate_metas(self):
170170

171171
for var in self.variables.metas + self.variables.outputs:
172172
if not var.distribution:
173-
fitter = Fitter(self.data[var.name], distributions=get_common_distributions())
174-
fitter.fit()
175-
(dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
176-
var.distribution = getattr(scipy.stats, dist)(**params)
177-
logger.info(var.name + f" {dist}({params})")
173+
try:
174+
fitter = Fitter(self.data[var.name], distributions=get_common_distributions())
175+
fitter.fit()
176+
(dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
177+
var.distribution = getattr(scipy.stats, dist)(**params)
178+
logger.info(var.name + f" {dist}({params})")
179+
except:
180+
logger.warn(f"Could not fit distriubtion for {var.name}.")
178181

179182
def _execute_test_case(
180183
self, causal_test_case: CausalTestCase, estimator: Estimator, f_flag: bool, conditions: list[str]
@@ -224,7 +227,7 @@ def _setup_test(
224227
- estimation_model - Estimator instance for the test being run
225228
"""
226229

227-
data_collector = ObservationalDataCollector(self.modelling_scenario, self.data.query(" & ".join(conditions)))
230+
data_collector = ObservationalDataCollector(self.modelling_scenario, self.data.query(" & ".join(conditions)) if conditions else self.data)
228231
causal_test_engine = CausalTestEngine(self.causal_specification, data_collector, index_col=0)
229232

230233
minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)

causal_testing/specification/metamorphic_relation.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,10 @@ def execute_tests(self, data_collector: ExperimentalDataCollector):
102102
def assertion(self, source_output, follow_up_output):
103103
"""An assertion that should be applied to an individual metamorphic test run."""
104104

105+
@abstractmethod
106+
def to_json_stub(self, skip=True) -> dict:
107+
"""Convert to a JSON frontend stub string for user customisation"""
108+
105109
@abstractmethod
106110
def test_oracle(self, test_results):
107111
"""A test oracle that assert whether the MR holds or not based on ALL test results.
@@ -129,6 +133,18 @@ def test_oracle(self, test_results):
129133
self.tests
130134
), f"{str(self)}: {len(test_results['fail'])}/{len(self.tests)} tests failed."
131135

136+
def to_json_stub(self, skip=True) -> dict:
137+
"""Convert to a JSON frontend stub string for user customisation"""
138+
return {
139+
"name": str(self),
140+
"estimator": "LinearRegressionEstimator",
141+
"estimate_type": "coefficient",
142+
"effect": "direct",
143+
"mutations": [self.treatment_var],
144+
"expectedEffect": {self.output_var: "SomeEffect"},
145+
"skip": skip
146+
}
147+
132148
def __str__(self):
133149
formatted_str = f"{self.treatment_var} --> {self.output_var}"
134150
if self.adjustment_vars:
@@ -149,6 +165,19 @@ def test_oracle(self, test_results):
149165
len(test_results["fail"]) == 0
150166
), f"{str(self)}: {len(test_results['fail'])}/{len(self.tests)} tests failed."
151167

168+
169+
def to_json_stub(self, skip=True) -> dict:
170+
"""Convert to a JSON frontend stub string for user customisation"""
171+
return {
172+
"name": str(self),
173+
"estimator": "LinearRegressionEstimator",
174+
"estimate_type": "coefficient",
175+
"effect": "direct",
176+
"mutations": [self.treatment_var],
177+
"expectedEffect": {self.output_var: "NoEffect"},
178+
"skip": skip
179+
}
180+
152181
def __str__(self):
153182
formatted_str = f"{self.treatment_var} _||_ {self.output_var}"
154183
if self.adjustment_vars:

causal_testing/testing/causal_test_outcome.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ class NoEffect(CausalTestOutcome):
3737
"""An extension of TestOutcome representing that the expected causal effect should be zero."""
3838

3939
def apply(self, res: CausalTestResult) -> bool:
40-
if res.test_value.type == "ate":
40+
print("RESULT", res)
41+
if res.test_value.type == "ate" or res.test_value.type == "coefficient":
4142
return (res.ci_low() < 0 < res.ci_high()) or (abs(res.test_value.value) < 1e-10)
4243
if res.test_value.type == "risk_ratio":
4344
return (res.ci_low() < 1 < res.ci_high()) or np.isclose(res.test_value.value, 1.0, atol=1e-10)

causal_testing/testing/causal_test_result.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,10 @@ def __init__(
4444
def __str__(self):
4545
base_str = (
4646
f"Causal Test Result\n==============\n"
47-
f"Treatment: {self.estimator.treatment[0]}\n"
47+
f"Treatment: {self.estimator.treatment}\n"
4848
f"Control value: {self.estimator.control_value}\n"
4949
f"Treatment value: {self.estimator.treatment_value}\n"
50-
f"Outcome: {self.estimator.outcome[0]}\n"
50+
f"Outcome: {self.estimator.outcome}\n"
5151
f"Adjustment set: {self.adjustment_set}\n"
5252
f"{self.test_value.type}: {self.test_value.value}\n"
5353
)

causal_testing/testing/estimators.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,6 @@ def add_modelling_assumptions(self):
7373
must hold if the resulting causal inference is to be considered valid.
7474
"""
7575

76-
@abstractmethod
77-
def estimate_ate(self) -> float:
78-
"""
79-
Estimate the unit effect of the treatment on the outcome. That is, the coefficient of the treatment variable
80-
in the linear regression equation.
81-
:return: The intercept and coefficient of the linear regression equation
82-
"""
83-
8476
def compute_confidence_intervals(self) -> list[float, float]:
8577
"""
8678
Estimate the 95% Wald confidence intervals for the effect of changing the treatment from control values to
@@ -535,21 +527,26 @@ def add_modelling_assumptions(self):
535527
(iii) Instrument and outcome do not share causes
536528
"""
537529

538-
def estimate_coefficient(self):
530+
def estimate_coefficient(self, df):
539531
"""
540532
Estimate the linear regression coefficient of the treatment on the outcome.
541533
"""
542534
# Estimate the total effect of instrument I on outcome Y = abI + c1
543-
ab = sm.OLS(self.df[self.outcome], self.df[[self.instrument]]).fit().params[self.instrument]
535+
ab = sm.OLS(df[self.outcome], df[[self.instrument]]).fit().params[self.instrument]
544536

545537
# Estimate the direct effect of instrument I on treatment X = aI + c1
546-
a = sm.OLS(self.df[self.treatment], self.df[[self.instrument]]).fit().params[self.instrument]
538+
a = sm.OLS(df[self.treatment], df[[self.instrument]]).fit().params[self.instrument]
547539

548540
# Estimate the coefficient of I on X by cancelling
549541
return ab / a
550542

551-
def estimate_ate(self):
552-
return (self.treatment_value - self.control_value) * self.estimate_coefficient(), (None, None)
543+
def estimate_unit_ate(self, bootstrap_size=100):
544+
bootstraps = sorted([self.estimate_coefficient(self.df.sample(len(self.df), replace=True)) for _ in range(bootstrap_size)])
545+
bound = ceil((bootstrap_size * 0.05) / 2)
546+
ci_low = bootstraps[bound]
547+
ci_high = bootstraps[bootstrap_size - bound]
548+
549+
return self.estimate_coefficient(self.df), (ci_low, ci_high)
553550

554551

555552
class CausalForestEstimator(Estimator):

0 commit comments

Comments
 (0)