Skip to content

Commit 147a22b

Browse files
committed
Fixed pytest
1 parent 20f39ef commit 147a22b

File tree

9 files changed

+71
-43
lines changed

9 files changed

+71
-43
lines changed

causal_testing/generation/abstract_causal_test_case.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def __init__(
3131
expected_causal_effect: dict[Variable:CausalTestOutcome],
3232
effect_modifiers: set[Variable] = None,
3333
estimate_type: str = "ate",
34-
effect: str = "total"
34+
effect: str = "total",
3535
):
3636
assert treatment_variable in scenario.variables.values(), (
3737
"Treatment variables must be a subset of variables."
@@ -105,9 +105,11 @@ def _generate_concrete_tests(
105105
for c in self.intervention_constraints:
106106
optimizer.assert_and_track(c, str(c))
107107

108-
109108
for v in run_columns:
110-
optimizer.add_soft(self.scenario.variables[v].z3 == self.scenario.variables[v].z3_val(self.scenario.variables[v].z3, row[v]))
109+
optimizer.add_soft(
110+
self.scenario.variables[v].z3
111+
== self.scenario.variables[v].z3_val(self.scenario.variables[v].z3, row[v])
112+
)
111113

112114
# optimizer.add_soft([optimizer.add_soft(self.scenario.variables[v].z3 == self.scenario.variables[v].z3_val(self.scenario.variables[v].z3, row[v])) for v in run_columns])
113115
if optimizer.check() == z3.unsat:
@@ -127,7 +129,7 @@ def _generate_concrete_tests(
127129
outcome_variables=list(self.expected_causal_effect.keys()),
128130
estimate_type=self.estimate_type,
129131
effect_modifier_configuration={v: v.cast(model[v.z3]) for v in self.effect_modifiers},
130-
effect=self.effect
132+
effect=self.effect,
131133
)
132134

133135
for v in self.scenario.inputs():
@@ -222,10 +224,14 @@ def generate_concrete_tests(
222224
control_values = [test.control_input_configuration[self.treatment_variable] for test in concrete_tests]
223225
treatment_values = [test.treatment_input_configuration[self.treatment_variable] for test in concrete_tests]
224226

225-
if self.treatment_variable.datatype is bool and set([(True, False), (False, True)]).issubset(set(zip(control_values, treatment_values))):
227+
if self.treatment_variable.datatype is bool and set([(True, False), (False, True)]).issubset(
228+
set(zip(control_values, treatment_values))
229+
):
226230
pre_break = True
227231
break
228-
if issubclass(self.treatment_variable.datatype, Enum) and set(itertools.product(self.treatment_variable.datatype, self.treatment_variable.datatype)).issubset(zip(control_values, treatment_values)):
232+
if issubclass(self.treatment_variable.datatype, Enum) and set(
233+
itertools.product(self.treatment_variable.datatype, self.treatment_variable.datatype)
234+
).issubset(zip(control_values, treatment_values)):
229235
pre_break = True
230236
break
231237
elif target_ks_score and all((stat <= target_ks_score for stat in ks_stats.values())):
@@ -237,6 +243,6 @@ def generate_concrete_tests(
237243
"Hard max reached but could not achieve target ks_score of %s. Got %s. Generated %s distinct tests",
238244
target_ks_score,
239245
ks_stats,
240-
len(concrete_tests)
246+
len(concrete_tests),
241247
)
242248
return concrete_tests, runs

causal_testing/json_front/json_class.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def _create_abstract_test_case(self, test, mutates, effects):
103103
if "effect_modifiers" in test
104104
else {},
105105
estimate_type=test["estimate_type"],
106-
effect=test.get("effect", "total")
106+
effect=test.get("effect", "total"),
107107
)
108108
return abstract_test
109109

@@ -190,11 +190,7 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima
190190
)
191191
if not test_passes:
192192
failed = True
193-
logger.warning(
194-
" FAILED- expected %s, got %s",
195-
causal_test_case.expected_causal_effect,
196-
result_string
197-
)
193+
logger.warning(" FAILED- expected %s, got %s", causal_test_case.expected_causal_effect, result_string)
198194
return failed
199195

200196
def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) -> tuple[CausalTestEngine, Estimator]:

causal_testing/specification/causal_dag.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,8 @@ def direct_effect_adjustment_sets(self, treatments: list[str], outcomes: list[st
255255
gam.add_edges_from(edges_to_add)
256256

257257
min_seps = list(list_all_min_sep(gam, "TREATMENT", "OUTCOME", set(treatments), set(outcomes)))
258-
min_seps.remove(set(outcomes))
258+
if set(outcomes) in min_seps:
259+
min_seps.remove(set(outcomes))
259260
return min_seps
260261

261262
def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> list[set[str]]:

causal_testing/specification/variable.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def z3_types(datatype):
2222
if datatype in types:
2323
return types[datatype]
2424
if issubclass(datatype, Enum):
25-
dtype, _ = EnumSort(datatype.__name__, [x.value for x in datatype])
25+
dtype, _ = EnumSort(datatype.__name__, [str(x.value) for x in datatype])
2626
return lambda x: Const(x, dtype)
2727
if hasattr(datatype, "to_z3"):
2828
return datatype.to_z3()
@@ -161,7 +161,9 @@ def cast(self, val: Any) -> T:
161161
return float(val.numerator().as_long() / val.denominator().as_long())
162162
if hasattr(val, "is_string_value") and val.is_string_value() and self.datatype == str:
163163
return val.as_string()
164-
if (isinstance(val, float) or isinstance(val, int) or isinstance(val, bool)) and (self.datatype == int or self.datatype == float or self.datatype == bool):
164+
if (isinstance(val, float) or isinstance(val, int) or isinstance(val, bool)) and (
165+
self.datatype == int or self.datatype == float or self.datatype == bool
166+
):
165167
return self.datatype(val)
166168
if issubclass(self.datatype, Enum) and isinstance(val, DatatypeRef):
167169
return self.datatype(str(val))

causal_testing/testing/estimators.py

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from econml.dml import CausalForestDML
99
from sklearn.ensemble import GradientBoostingRegressor
1010
from statsmodels.regression.linear_model import RegressionResultsWrapper
11+
from statsmodels.tools.sm_exceptions import PerfectSeparationError
1112

1213
from causal_testing.specification.variable import Variable
1314

@@ -144,7 +145,9 @@ def _run_logistic_regression(self, data) -> RegressionResultsWrapper:
144145
outcome_col = reduced_df[list(self.outcome)]
145146
for col in treatment_and_adjustments_cols:
146147
if str(treatment_and_adjustments_cols.dtypes[col]) == "object":
147-
treatment_and_adjustments_cols = pd.get_dummies(treatment_and_adjustments_cols, columns=[col], drop_first=True)
148+
treatment_and_adjustments_cols = pd.get_dummies(
149+
treatment_and_adjustments_cols, columns=[col], drop_first=True
150+
)
148151
regression = sm.Logit(outcome_col, treatment_and_adjustments_cols)
149152
model = regression.fit()
150153
return model
@@ -181,9 +184,16 @@ def estimate_control_treatment(self, bootstrap_size=100) -> tuple[pd.Series, pd.
181184

182185
y = self.estimate(self.df)
183186

184-
bootstrap_samples = [self.estimate(self.df.sample(len(self.df), replace=True)) for _ in range(bootstrap_size)]
185-
control, treatment = zip(*[(x.iloc[1], x.iloc[0]) for x in bootstrap_samples])
186-
187+
try:
188+
bootstrap_samples = [
189+
self.estimate(self.df.sample(len(self.df), replace=True)) for _ in range(bootstrap_size)
190+
]
191+
control, treatment = zip(*[(x.iloc[1], x.iloc[0]) for x in bootstrap_samples])
192+
except PerfectSeparationError:
193+
logger.warning(
194+
"Perfect separation detected, results not available. Cannot calculate confidence intervals for such a small dataset."
195+
)
196+
return (y.iloc[1], None), (y.iloc[0], None)
187197

188198
# Delta method confidence intervals from
189199
# https://stackoverflow.com/questions/47414842/confidence-interval-of-probability-prediction-from-logistic-regression-statsmode
@@ -204,11 +214,17 @@ def estimate_ate(self, bootstrap_size=100) -> float:
204214
205215
:return: The estimated average treatment effect and 95% confidence intervals
206216
"""
207-
(control_outcome, control_bootstraps), (treatment_outcome, treatment_bootstraps) = self.estimate_control_treatment()
208-
217+
(control_outcome, control_bootstraps), (
218+
treatment_outcome,
219+
treatment_bootstraps,
220+
) = self.estimate_control_treatment()
209221
estimate = treatment_outcome - control_outcome
222+
223+
if control_bootstraps is None or treatment_bootstraps is None:
224+
return estimate, (None, None)
225+
210226
bootstraps = sorted(list(treatment_bootstraps - control_bootstraps))
211-
bound = int((bootstrap_size * 0.05)/2)
227+
bound = int((bootstrap_size * 0.05) / 2)
212228
ci_low = bootstraps[bound]
213229
ci_high = bootstraps[bootstrap_size - bound]
214230

@@ -227,11 +243,17 @@ def estimate_risk_ratio(self) -> float:
227243
228244
:return: The estimated risk ratio and 95% confidence intervals.
229245
"""
230-
(control_outcome, control_bootstraps), (treatment_outcome, treatment_bootstraps) = self.estimate_control_treatment()
231-
246+
(control_outcome, control_bootstraps), (
247+
treatment_outcome,
248+
treatment_bootstraps,
249+
) = self.estimate_control_treatment()
232250
estimate = treatment_outcome / control_outcome
251+
252+
if control_bootstraps is None or treatment_bootstraps is None:
253+
return estimate, (None, None)
254+
233255
bootstraps = sorted(list(treatment_bootstraps / control_bootstraps))
234-
bound = int((bootstrap_size * 0.05)/2)
256+
bound = int((bootstrap_size * 0.05) / 2)
235257
ci_low = bootstraps[bound]
236258
ci_high = bootstraps[bootstrap_size - bound]
237259

@@ -248,7 +270,7 @@ def estimate_unit_odds_ratio(self) -> float:
248270
249271
:return: The odds ratio. Confidence intervals are not yet supported.
250272
"""
251-
model = self._run_logistic_regression()
273+
model = self._run_logistic_regression(self.df)
252274
return np.exp(model.params[self.treatment[0]])
253275

254276

@@ -390,7 +412,6 @@ def estimate_control_treatment(self) -> tuple[pd.Series, pd.Series]:
390412
model = self._run_linear_regression()
391413
self.model = model
392414

393-
394415
x = pd.DataFrame()
395416
x[self.treatment[0]] = [self.treatment_values, self.control_values]
396417
x["Intercept"] = self.intercept
@@ -489,7 +510,9 @@ def _run_linear_regression(self) -> RegressionResultsWrapper:
489510
outcome_col = reduced_df[list(self.outcome)]
490511
for col in treatment_and_adjustments_cols:
491512
if str(treatment_and_adjustments_cols.dtypes[col]) == "object":
492-
treatment_and_adjustments_cols = pd.get_dummies(treatment_and_adjustments_cols, columns=[col], drop_first=True)
513+
treatment_and_adjustments_cols = pd.get_dummies(
514+
treatment_and_adjustments_cols, columns=[col], drop_first=True
515+
)
493516
regression = sm.OLS(outcome_col, treatment_and_adjustments_cols)
494517
model = regression.fit()
495518
return model

tests/generation_tests/test_abstract_test_case.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def test_generate_concrete_test_cases(self):
3636
abstract = AbstractCausalTestCase(
3737
scenario=scenario,
3838
intervention_constraints={scenario.treatment_variables[self.X1.name].z3 > self.X1.z3},
39-
treatment_variables={self.X1},
39+
treatment_variable=self.X1,
4040
expected_causal_effect={self.Y: Positive()},
4141
effect_modifiers=None,
4242
)
@@ -50,7 +50,7 @@ def test_str(self):
5050
abstract = AbstractCausalTestCase(
5151
scenario=scenario,
5252
intervention_constraints={scenario.treatment_variables[self.X1.name].z3 > self.X1.z3},
53-
treatment_variables={self.X1},
53+
treatment_variable=self.X1,
5454
expected_causal_effect={self.Y: Positive()},
5555
effect_modifiers=None,
5656
)
@@ -64,7 +64,7 @@ def test_datapath(self):
6464
abstract = AbstractCausalTestCase(
6565
scenario=scenario,
6666
intervention_constraints={scenario.treatment_variables[self.X1.name].z3 > self.X1.z3},
67-
treatment_variables={self.X1},
67+
treatment_variable=self.X1,
6868
expected_causal_effect={self.Y: Positive()},
6969
effect_modifiers=None,
7070
)
@@ -76,7 +76,7 @@ def test_generate_concrete_test_cases_with_constraints(self):
7676
abstract = AbstractCausalTestCase(
7777
scenario=scenario,
7878
intervention_constraints={scenario.treatment_variables[self.X1.name].z3 > self.X1.z3},
79-
treatment_variables={self.X1},
79+
treatment_variable=self.X1,
8080
expected_causal_effect={self.Y: Positive()},
8181
effect_modifiers=None,
8282
)
@@ -90,7 +90,7 @@ def test_generate_concrete_test_cases_with_effect_modifiers(self):
9090
abstract = AbstractCausalTestCase(
9191
scenario=scenario,
9292
intervention_constraints={scenario.treatment_variables[self.X1.name].z3 > self.X1.z3},
93-
treatment_variables={self.X1},
93+
treatment_variable=self.X1,
9494
expected_causal_effect={self.Y: Positive()},
9595
effect_modifiers={self.X2},
9696
)
@@ -104,7 +104,7 @@ def test_generate_concrete_test_cases_rct(self):
104104
abstract = AbstractCausalTestCase(
105105
scenario=scenario,
106106
intervention_constraints={scenario.treatment_variables[self.X1.name].z3 > self.X1.z3},
107-
treatment_variables={self.X1},
107+
treatment_variable=self.X1,
108108
expected_causal_effect={self.Y: Positive()},
109109
effect_modifiers=None,
110110
)
@@ -118,7 +118,7 @@ def test_infeasible_constraints(self):
118118
abstract = AbstractCausalTestCase(
119119
scenario=scenario,
120120
intervention_constraints={scenario.treatment_variables[self.X1.name].z3 > self.X1.z3},
121-
treatment_variables={self.X1},
121+
treatment_variable=self.X1,
122122
expected_causal_effect={self.Y: Positive()},
123123
effect_modifiers=None,
124124
)
@@ -128,15 +128,15 @@ def test_infeasible_constraints(self):
128128
with self.assertWarns(Warning):
129129
concrete_tests, runs = abstract.generate_concrete_tests(4, rct=True, target_ks_score=0.1, hard_max=HARD_MAX)
130130
self.assertTrue(all((x > 2 for x in runs["X1"])))
131-
self.assertEqual(len(concrete_tests), HARD_MAX * NUM_STRATA)
131+
self.assertTrue(len(concrete_tests) <= HARD_MAX * NUM_STRATA)
132132

133133
def test_feasible_constraints(self):
134134
scenario = Scenario({self.X1, self.X2, self.X3, self.X4})
135135
scenario.setup_treatment_variables()
136136
abstract = AbstractCausalTestCase(
137137
scenario=scenario,
138138
intervention_constraints={scenario.treatment_variables[self.X1.name].z3 > self.X1.z3},
139-
treatment_variables={self.X1},
139+
treatment_variable=self.X1,
140140
expected_causal_effect={self.Y: Positive()},
141141
effect_modifiers=None,
142142
)

tests/specification_tests/test_causal_dag.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def setUp(self) -> None:
107107
def test_direct_effect_adjustment_sets(self):
108108
causal_dag = CausalDAG(self.dag_dot_path)
109109
adjustment_sets = causal_dag.direct_effect_adjustment_sets(["X1"], ["Y"])
110-
self.assertEqual(list(adjustment_sets), [{"Y"}, {"D1", "Z"}, {"X2", "Z"}])
110+
self.assertEqual(list(adjustment_sets), [{"D1", "Z"}, {"X2", "Z"}])
111111

112112
def test_direct_effect_adjustment_sets_no_adjustment(self):
113113
causal_dag = CausalDAG(self.dag_dot_path)

tests/specification_tests/test_variable.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ class Color(Enum):
2020
Example enum class color.
2121
"""
2222

23-
RED = 1
24-
GREEN = 2
25-
BLUE = 3
23+
RED = "RED"
24+
GREEN = "GREEN"
25+
BLUE = "BLUE"
2626

2727
dtype, _ = z3.EnumSort("color", ("RED", "GREEN", "BLUE"))
2828
z3_color = z3.Const("color", dtype)

tests/testing_tests/test_causal_test_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def test_execute_test_observational_linear_regression_estimator_direct_effect(se
185185
causal_test_engine.scenario_execution_data_df,
186186
)
187187
causal_test_result = causal_test_engine.execute_test(estimation_model, causal_test_case)
188-
self.assertAlmostEqual(causal_test_result.test_value.value, 0, delta=1e-10)
188+
self.assertAlmostEqual(causal_test_result.test_value.value, 4, delta=1e-10)
189189

190190
def test_execute_test_observational_linear_regression_estimator_risk_ratio(self):
191191
"""Check that executing the causal test case returns the correct results for dummy data using a linear

0 commit comments

Comments
 (0)