
Commit 53a0884

black
1 parent 36b8c8a commit 53a0884
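
The commit message refers to the Black autoformatter, so the changes below appear to be a mechanical reformat rather than a behavioural change. As a rough, hypothetical illustration (not part of the commit, and assuming Black's default 88-character line length), black.format_str reflows an over-long signature like the one in data_collector.py into exactly the shape seen in this diff:

import black

# Hypothetical reconstruction of the pre-commit one-line signature from
# data_collector.py; Black only needs syntactically valid code, so the
# undefined name `pd` does not matter here.
src = (
    "def filter_valid_data(self, data: pd.DataFrame, check_pos: bool = True)"
    " -> pd.DataFrame:\n"
    "    pass\n"
)

# black.format_str applies the same rules as the `black` CLI. With the default
# Mode, the signature exceeds the 88-character limit and is split across three
# lines, matching the shape of the changes in this commit.
print(black.format_str(src, mode=black.Mode()))

Running the black command-line tool over the whole repository applies the same rules to every file, which would account for the 12-file diff.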

12 files changed (+603, -301 lines)


causal_testing/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 import logging
+
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 logger.addHandler(logging.StreamHandler())

causal_testing/data_collection/data_collector.py

Lines changed: 29 additions & 10 deletions
@@ -21,7 +21,9 @@ def collect_data(self, **kwargs) -> pd.DataFrame:
         """
         pass
 
-    def filter_valid_data(self, data: pd.DataFrame, check_pos: bool = True) -> pd.DataFrame:
+    def filter_valid_data(
+        self, data: pd.DataFrame, check_pos: bool = True
+    ) -> pd.DataFrame:
         """Check is execution data is valid for the scenario-under-test.
 
         Data is invalid if it does not meet the constraints specified in the scenario-under-test.
@@ -37,7 +39,9 @@ def filter_valid_data(self, data: pd.DataFrame, check_pos: bool = True) -> pd.Da
 
         if check_pos and not scenario_variables.issubset(data.columns):
             missing_variables = scenario_variables - set(data.columns)
-            raise IndexError(f"Positivity violation: missing data for variables {missing_variables}.")
+            raise IndexError(
+                f"Positivity violation: missing data for variables {missing_variables}."
+            )
 
         # For each row, does it satisfy the constraints?
         solver = z3.Solver()
@@ -72,8 +76,10 @@ def filter_valid_data(self, data: pd.DataFrame, check_pos: bool = True) -> pd.Da
         # How many rows did we drop?
         size_diff = len(data) - len(satisfying_data)
         if size_diff > 0:
-            logger.warning(f"Discarded {size_diff}/{len(data)} values due to constraint violations.\n"
-                           f"For example{unsat_core}")
+            logger.warning(
+                f"Discarded {size_diff}/{len(data)} values due to constraint violations.\n"
+                f"For example{unsat_core}"
+            )
         return satisfying_data
 
 
@@ -83,8 +89,13 @@ class ExperimentalDataCollector(DataCollector):
     Users should implement these methods to collect data from their system.
     """
 
-    def __init__(self, scenario: Scenario, control_input_configuration: dict, treatment_input_configuration: dict,
-                 n_repeats: int = 1):
+    def __init__(
+        self,
+        scenario: Scenario,
+        control_input_configuration: dict,
+        treatment_input_configuration: dict,
+        n_repeats: int = 1,
+    ):
         super().__init__(scenario)
         self.control_input_configuration = control_input_configuration
         self.treatment_input_configuration = treatment_input_configuration
@@ -97,13 +108,21 @@ def collect_data(self, **kwargs) -> pd.DataFrame:
         :return: A pandas dataframe containing execution data for the system-under-test in both control and treatment
             executions.
         """
-        control_results_df = self.run_system_with_input_configuration(self.control_input_configuration)
-        treatment_results_df = self.run_system_with_input_configuration(self.treatment_input_configuration)
-        results_df = pd.concat([control_results_df, treatment_results_df], ignore_index=True)
+        control_results_df = self.run_system_with_input_configuration(
+            self.control_input_configuration
+        )
+        treatment_results_df = self.run_system_with_input_configuration(
+            self.treatment_input_configuration
+        )
+        results_df = pd.concat(
+            [control_results_df, treatment_results_df], ignore_index=True
+        )
         return results_df
 
     @abstractmethod
-    def run_system_with_input_configuration(self, input_configuration: dict) -> pd.DataFrame:
+    def run_system_with_input_configuration(
+        self, input_configuration: dict
+    ) -> pd.DataFrame:
         """Run the system with a given input configuration and return the resulting execution data.
 
         :param input_configuration: A dictionary which maps a subset of inputs to values.
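
As an aside (not part of this commit): the reformatted collect_data above calls run_system_with_input_configuration once for the control configuration and once for the treatment configuration, then concatenates the two result frames, so a user's subclass only needs to supply that one method. A minimal hypothetical sketch, assuming the package is importable as causal_testing and that run_system_with_input_configuration is the only method that must be overridden here:

import pandas as pd

from causal_testing.data_collection.data_collector import ExperimentalDataCollector


class MySystemDataCollector(ExperimentalDataCollector):
    # Hypothetical subclass for illustration only.
    def run_system_with_input_configuration(
        self, input_configuration: dict
    ) -> pd.DataFrame:
        # Stand-in for a real call into the system-under-test: echo the inputs
        # alongside a fabricated output column "y" so that collect_data can
        # concatenate the control and treatment runs into a single dataframe.
        output = {"y": sum(input_configuration.values())}
        return pd.DataFrame([{**input_configuration, **output}])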

causal_testing/generation/abstract_causal_test_case.py

Lines changed: 66 additions & 26 deletions
@@ -25,45 +25,52 @@ def __init__(
         scenario: Scenario,
         intervention_constraints: set[z3.ExprRef],
         treatment_variables: set[Variable],
-        expected_causal_effect: dict[Variable: CausalTestOutcome],
+        expected_causal_effect: dict[Variable:CausalTestOutcome],
         effect_modifiers: set[Variable] = None,
-        estimate_type: str = "ate"
+        estimate_type: str = "ate",
     ):
         assert treatment_variables.issubset(scenario.variables.values()), (
             "Treatment variables must be a subset of variables."
             + f" Instead got:\ntreatment_variables={treatment_variables}\nvariables={scenario.variables}"
         )
 
-        assert len(expected_causal_effect) == 1, "We currently only support tests with one causal outcome"
+        assert (
+            len(expected_causal_effect) == 1
+        ), "We currently only support tests with one causal outcome"
 
         self.scenario = scenario
         self.intervention_constraints = intervention_constraints
         self.treatment_variables = treatment_variables
         self.expected_causal_effect = expected_causal_effect
-        self.estimate_type=estimate_type
+        self.estimate_type = estimate_type
 
         if effect_modifiers is not None:
             self.effect_modifiers = effect_modifiers
         else:
             self.effect_modifiers = {}
 
     def __str__(self):
-        outcome_string = " and ".join([f"the effect on {var} should be {str(effect)}" for var, effect in self.expected_causal_effect.items()])
-        return (
-            f"When we apply intervention {self.intervention_constraints}, {outcome_string}"
+        outcome_string = " and ".join(
+            [
+                f"the effect on {var} should be {str(effect)}"
+                for var, effect in self.expected_causal_effect.items()
+            ]
         )
+        return f"When we apply intervention {self.intervention_constraints}, {outcome_string}"
 
     def datapath(self):
         def sanitise(string):
             return "".join([x for x in string if x.isalnum()])
 
         return (
             sanitise("-".join([str(c) for c in self.intervention_constraints]))
-            + "_"+'-'.join([f"{v.name}_{e}" for v, e in self.expected_causal_effect.items()])
+            + "_"
+            + "-".join(
+                [f"{v.name}_{e}" for v, e in self.expected_causal_effect.items()]
+            )
             + ".csv"
         )
 
-
     def _generate_concrete_tests(
         self, sample_size: int, rct: bool = False, seed: int = 0
     ) -> tuple[list[CausalTestCase], pd.DataFrame]:
@@ -80,8 +87,9 @@ def _generate_concrete_tests(
 
         concrete_tests = []
         runs = []
-        run_columns = sorted([v.name for v in self.scenario.variables.values() if v.distribution])
-
+        run_columns = sorted(
+            [v.name for v in self.scenario.variables.values() if v.distribution]
+        )
 
         # Generate the Latin Hypercube samples and put into a dataframe
         # lhsmdu.setRandomSeed(seed+i)
@@ -103,7 +111,9 @@ def _generate_concrete_tests(
             for c in self.intervention_constraints:
                 optimizer.assert_and_track(c, str(c))
 
-            optimizer.add_soft([self.scenario.variables[v].z3 == row[v] for v in run_columns])
+            optimizer.add_soft(
+                [self.scenario.variables[v].z3 == row[v] for v in run_columns]
+            )
             if optimizer.check() == z3.unsat:
                 logger.warning(
                     "Satisfiability of test case was unsat.\n"
@@ -122,9 +132,9 @@ def _generate_concrete_tests(
                     expected_causal_effect=list(self.expected_causal_effect.values())[0],
                     outcome_variables=list(self.expected_causal_effect.keys()),
                     estimate_type=self.estimate_type,
-                    effect_modifier_configuration = {
+                    effect_modifier_configuration={
                         v: v.cast(model[v.z3]) for v in self.effect_modifiers
-                    }
+                    },
                )
 
            for v in self.scenario.inputs():
@@ -160,9 +170,13 @@ def _generate_concrete_tests(
 
         return concrete_tests, pd.DataFrame(runs, columns=run_columns + ["bin"])
 
-
     def generate_concrete_tests(
-        self, sample_size: int, target_ks_score: float = None, rct: bool = False, seed: int = 0, hard_max: int = 1000
+        self,
+        sample_size: int,
+        target_ks_score: float = None,
+        rct: bool = False,
+        seed: int = 0,
+        hard_max: int = 1000,
     ) -> tuple[list[CausalTestCase], pd.DataFrame]:
         """Generates a list of `num` concrete test cases.
 
@@ -189,14 +203,22 @@ def generate_concrete_tests(
         ks_stats = []
 
         for i in range(hard_max):
-            concrete_tests_, runs_ = self._generate_concrete_tests(sample_size, rct, seed+i)
+            concrete_tests_, runs_ = self._generate_concrete_tests(
+                sample_size, rct, seed + i
+            )
             concrete_tests += concrete_tests_
             runs = pd.concat([runs, runs_])
-            assert concrete_tests_ not in concrete_tests, "Duplicate entries unlikely unless something went wrong"
-
+            assert (
+                concrete_tests_ not in concrete_tests
+            ), "Duplicate entries unlikely unless something went wrong"
 
-            control_configs = pd.DataFrame([test.control_input_configuration for test in concrete_tests])
-            ks_stats = {var: stats.kstest(control_configs[var], var.distribution.cdf).statistic for var in control_configs.columns}
+            control_configs = pd.DataFrame(
+                [test.control_input_configuration for test in concrete_tests]
+            )
+            ks_stats = {
+                var: stats.kstest(control_configs[var], var.distribution.cdf).statistic
+                for var in control_configs.columns
+            }
             # Putting treatment and control values in messes it up because the two are not independent...
             # This is potentially problematic as constraints might mean we don't get good coverage if we use control values alone
             # We might then need to carefully craft our _control value_ generating distributions so that we can get good coverage
@@ -205,11 +227,29 @@ def generate_concrete_tests(
             # treatment_configs = pd.DataFrame([test.treatment_input_configuration for test in concrete_tests])
             # both_configs = pd.concat([control_configs, treatment_configs])
             # ks_stats = {var: stats.kstest(both_configs[var], var.distribution.cdf).statistic for var in both_configs.columns}
-            effect_modifier_configs = pd.DataFrame([test.effect_modifier_configuration for test in concrete_tests])
-            ks_stats.update({var: stats.kstest(effect_modifier_configs[var], var.distribution.cdf).statistic for var in effect_modifier_configs.columns})
-            if target_ks_score and all((stat <= target_ks_score for stat in ks_stats.values())):
+            effect_modifier_configs = pd.DataFrame(
+                [test.effect_modifier_configuration for test in concrete_tests]
+            )
+            ks_stats.update(
+                {
+                    var: stats.kstest(
+                        effect_modifier_configs[var], var.distribution.cdf
+                    ).statistic
+                    for var in effect_modifier_configs.columns
+                }
+            )
+            if target_ks_score and all(
+                (stat <= target_ks_score for stat in ks_stats.values())
+            ):
                 break
 
-        if target_ks_score is not None and not all((stat <= target_ks_score for stat in ks_stats.values())):
-            logger.error("Hard max of %s reached but could not achieve target ks_score of %s. Got %s.", hard_max, target_ks_score, ks_stats)
+        if target_ks_score is not None and not all(
+            (stat <= target_ks_score for stat in ks_stats.values())
+        ):
+            logger.error(
+                "Hard max of %s reached but could not achieve target ks_score of %s. Got %s.",
+                hard_max,
+                target_ks_score,
+                ks_stats,
+            )
         return concrete_tests, runs
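
One contextual note on the KS check reformatted above (again, not introduced by this commit): stats.kstest compares the sampled control and effect-modifier values against each variable's target distribution, and generate_concrete_tests keeps resampling until every statistic falls below target_ks_score or hard_max is reached. A standalone sketch of that call pattern, using a hypothetical uniformly distributed input:

import numpy as np
from scipy import stats

# Hypothetical samples for one input variable that should follow U(0, 1).
rng = np.random.default_rng(0)
samples = rng.uniform(0.0, 1.0, size=50)

# kstest accepts a callable CDF; .statistic is the Kolmogorov-Smirnov distance
# between the empirical sample and the target distribution. Smaller values mean
# the generated test inputs cover the intended distribution more faithfully.
ks_statistic = stats.kstest(samples, stats.uniform(0, 1).cdf).statistic
print(ks_statistic)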

0 commit comments
