
Commit b7f1185

black
1 parent 06ec318 commit b7f1185

10 files changed: +117, -365 lines changed
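The commit message and the shape of the diff suggest this is a pure formatting pass with the Black auto-formatter: wrapped calls, signatures, and comprehensions are joined back onto one line wherever they fit. Several joined lines run well past Black's default limit of 88 characters, so the project presumably configures a longer line length; 120 is assumed in the sketch below, since the configuration itself is not part of this diff. A minimal sketch of the transformation using Black's Python API:

import black

# One of the wrapped signatures from this diff, as Black would have seen it.
wrapped = (
    "def filter_valid_data(\n"
    "    self, data, check_pos: bool = True\n"
    "):\n"
    "    pass\n"
)

# With line_length=120 the signature fits on one line, so Black joins it --
# the same change applied throughout this commit.
print(black.format_str(wrapped, mode=black.FileMode(line_length=120)))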

causal_testing/data_collection/data_collector.py

Lines changed: 8 additions & 23 deletions

@@ -21,9 +21,7 @@ def collect_data(self, **kwargs) -> pd.DataFrame:
         """
         pass

-    def filter_valid_data(
-        self, data: pd.DataFrame, check_pos: bool = True
-    ) -> pd.DataFrame:
+    def filter_valid_data(self, data: pd.DataFrame, check_pos: bool = True) -> pd.DataFrame:
         """Check is execution data is valid for the scenario-under-test.

         Data is invalid if it does not meet the constraints specified in the scenario-under-test.
@@ -39,9 +37,7 @@ def filter_valid_data(

         if check_pos and not scenario_variables.issubset(data.columns):
             missing_variables = scenario_variables - set(data.columns)
-            raise IndexError(
-                f"Positivity violation: missing data for variables {missing_variables}."
-            )
+            raise IndexError(f"Positivity violation: missing data for variables {missing_variables}.")

         # For each row, does it satisfy the constraints?
         solver = z3.Solver()
@@ -54,9 +50,7 @@ def filter_valid_data(
             # Need to explicitly cast variables to their specified type. Z3 will not take e.g. np.int64 to be an int.
             model = [
                 self.scenario.variables[var].z3
-                == self.scenario.variables[var].z3_val(
-                    self.scenario.variables[var].z3, row[var]
-                )
+                == self.scenario.variables[var].z3_val(self.scenario.variables[var].z3, row[var])
                 for var in self.scenario.variables
             ]
             for c in model:
@@ -77,8 +71,7 @@ def filter_valid_data(
         size_diff = len(data) - len(satisfying_data)
         if size_diff > 0:
             logger.warning(
-                f"Discarded {size_diff}/{len(data)} values due to constraint violations.\n"
-                f"For example{unsat_core}"
+                f"Discarded {size_diff}/{len(data)} values due to constraint violations.\n" f"For example{unsat_core}"
             )
         return satisfying_data

@@ -108,21 +101,13 @@ def collect_data(self, **kwargs) -> pd.DataFrame:
         :return: A pandas dataframe containing execution data for the system-under-test in both control and treatment
         executions.
         """
-        control_results_df = self.run_system_with_input_configuration(
-            self.control_input_configuration
-        )
-        treatment_results_df = self.run_system_with_input_configuration(
-            self.treatment_input_configuration
-        )
-        results_df = pd.concat(
-            [control_results_df, treatment_results_df], ignore_index=True
-        )
+        control_results_df = self.run_system_with_input_configuration(self.control_input_configuration)
+        treatment_results_df = self.run_system_with_input_configuration(self.treatment_input_configuration)
+        results_df = pd.concat([control_results_df, treatment_results_df], ignore_index=True)
         return results_df

     @abstractmethod
-    def run_system_with_input_configuration(
-        self, input_configuration: dict
-    ) -> pd.DataFrame:
+    def run_system_with_input_configuration(self, input_configuration: dict) -> pd.DataFrame:
         """Run the system with a given input configuration and return the resulting execution data.

         :param input_configuration: A dictionary which maps a subset of inputs to values.
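For context, filter_valid_data validates each row of execution data by asserting the row's values as tracked facts and asking Z3 whether they are consistent with the scenario constraints; inconsistent rows are discarded and the unsat core is named in the warning. A minimal, self-contained sketch of that pattern (the constraint and row values here are hypothetical, not taken from the project):

import z3

width = z3.Int("width")
scenario_constraint = width > 0  # hypothetical scenario constraint

for row_value in [5, -3]:
    solver = z3.Solver()
    solver.add(scenario_constraint)
    # Track the row's value so it can be named in the unsat core.
    solver.assert_and_track(width == row_value, f"width == {row_value}")
    if solver.check() == z3.unsat:
        print("Discarding row, unsat core:", solver.unsat_core())
    else:
        print(f"Row with width={row_value} satisfies the constraints")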

causal_testing/generation/abstract_causal_test_case.py

Lines changed: 19 additions & 60 deletions

@@ -34,9 +34,7 @@ def __init__(
                 + f" Instead got:\ntreatment_variables={treatment_variables}\nvariables={scenario.variables}"
             )

-        assert (
-            len(expected_causal_effect) == 1
-        ), "We currently only support tests with one causal outcome"
+        assert len(expected_causal_effect) == 1, "We currently only support tests with one causal outcome"

         self.scenario = scenario
         self.intervention_constraints = intervention_constraints
@@ -51,10 +49,7 @@ def __init__(

     def __str__(self):
         outcome_string = " and ".join(
-            [
-                f"the effect on {var} should be {str(effect)}"
-                for var, effect in self.expected_causal_effect.items()
-            ]
+            [f"the effect on {var} should be {str(effect)}" for var, effect in self.expected_causal_effect.items()]
         )
         return f"When we apply intervention {self.intervention_constraints}, {outcome_string}"

@@ -65,9 +60,7 @@ def sanitise(string):
         return (
             sanitise("-".join([str(c) for c in self.intervention_constraints]))
             + "_"
-            + "-".join(
-                [f"{v.name}_{e}" for v, e in self.expected_causal_effect.items()]
-            )
+            + "-".join([f"{v.name}_{e}" for v, e in self.expected_causal_effect.items()])
             + ".csv"
         )

@@ -87,9 +80,7 @@ def _generate_concrete_tests(

         concrete_tests = []
         runs = []
-        run_columns = sorted(
-            [v.name for v in self.scenario.variables.values() if v.distribution]
-        )
+        run_columns = sorted([v.name for v in self.scenario.variables.values() if v.distribution])

         # Generate the Latin Hypercube samples and put into a dataframe
         # lhsmdu.setRandomSeed(seed+i)
@@ -100,9 +91,7 @@ def _generate_concrete_tests(
         # Project the samples to the variables' distributions
         for name in run_columns:
             var = self.scenario.variables[name]
-            samples[var.name] = lhsmdu.inverseTransformSample(
-                var.distribution, samples[var.name]
-            )
+            samples[var.name] = lhsmdu.inverseTransformSample(var.distribution, samples[var.name])

         for index, row in samples.iterrows():
             optimizer = z3.Optimize()
@@ -111,9 +100,7 @@ def _generate_concrete_tests(
             for c in self.intervention_constraints:
                 optimizer.assert_and_track(c, str(c))

-            optimizer.add_soft(
-                [self.scenario.variables[v].z3 == row[v] for v in run_columns]
-            )
+            optimizer.add_soft([self.scenario.variables[v].z3 == row[v] for v in run_columns])
             if optimizer.check() == z3.unsat:
                 logger.warning(
                     "Satisfiability of test case was unsat.\n"
@@ -122,26 +109,19 @@ def _generate_concrete_tests(
             model = optimizer.model()

             concrete_test = CausalTestCase(
-                control_input_configuration={
-                    v: v.cast(model[v.z3]) for v in self.treatment_variables
-                },
+                control_input_configuration={v: v.cast(model[v.z3]) for v in self.treatment_variables},
                 treatment_input_configuration={
-                    v: v.cast(model[self.scenario.treatment_variables[v.name].z3])
-                    for v in self.treatment_variables
+                    v: v.cast(model[self.scenario.treatment_variables[v.name].z3]) for v in self.treatment_variables
                 },
                 expected_causal_effect=list(self.expected_causal_effect.values())[0],
                 outcome_variables=list(self.expected_causal_effect.keys()),
                 estimate_type=self.estimate_type,
-                effect_modifier_configuration={
-                    v: v.cast(model[v.z3]) for v in self.effect_modifiers
-                },
+                effect_modifier_configuration={v: v.cast(model[v.z3]) for v in self.effect_modifiers},
             )

             for v in self.scenario.inputs():
                 if row[v.name] != v.cast(model[v.z3]):
-                    constraints = "\n ".join(
-                        [str(c) for c in self.scenario.constraints if v.name in str(c)]
-                    )
+                    constraints = "\n ".join([str(c) for c in self.scenario.constraints if v.name in str(c)])
                     logger.warning(
                         f"Unable to set variable {v.name} to {row[v.name]} because of constraints\n"
                         + f"{constraints}\nUsing value {v.cast(model[v.z3])} instead in test\n{concrete_test}"
@@ -150,21 +130,14 @@ def _generate_concrete_tests(
             concrete_tests.append(concrete_test)
             # Control run
             control_run = {
-                v.name: v.cast(model[v.z3])
-                for v in self.scenario.variables.values()
-                if v.name in run_columns
+                v.name: v.cast(model[v.z3]) for v in self.scenario.variables.values() if v.name in run_columns
             }
             control_run["bin"] = index
             runs.append(control_run)
             # Treatment run
             if rct:
                 treatment_run = control_run.copy()
-                treatment_run.update(
-                    {
-                        k.name: v
-                        for k, v in concrete_test.treatment_input_configuration.items()
-                    }
-                )
+                treatment_run.update({k.name: v for k, v in concrete_test.treatment_input_configuration.items()})
                 treatment_run["bin"] = index
                 runs.append(treatment_run)

@@ -203,18 +176,12 @@ def generate_concrete_tests(
         ks_stats = []

         for i in range(hard_max):
-            concrete_tests_, runs_ = self._generate_concrete_tests(
-                sample_size, rct, seed + i
-            )
+            concrete_tests_, runs_ = self._generate_concrete_tests(sample_size, rct, seed + i)
             concrete_tests += concrete_tests_
             runs = pd.concat([runs, runs_])
-            assert (
-                concrete_tests_ not in concrete_tests
-            ), "Duplicate entries unlikely unless something went wrong"
+            assert concrete_tests_ not in concrete_tests, "Duplicate entries unlikely unless something went wrong"

-            control_configs = pd.DataFrame(
-                [test.control_input_configuration for test in concrete_tests]
-            )
+            control_configs = pd.DataFrame([test.control_input_configuration for test in concrete_tests])
             ks_stats = {
                 var: stats.kstest(control_configs[var], var.distribution.cdf).statistic
                 for var in control_configs.columns
@@ -227,25 +194,17 @@ def generate_concrete_tests(
             # treatment_configs = pd.DataFrame([test.treatment_input_configuration for test in concrete_tests])
             # both_configs = pd.concat([control_configs, treatment_configs])
             # ks_stats = {var: stats.kstest(both_configs[var], var.distribution.cdf).statistic for var in both_configs.columns}
-            effect_modifier_configs = pd.DataFrame(
-                [test.effect_modifier_configuration for test in concrete_tests]
-            )
+            effect_modifier_configs = pd.DataFrame([test.effect_modifier_configuration for test in concrete_tests])
             ks_stats.update(
                 {
-                    var: stats.kstest(
-                        effect_modifier_configs[var], var.distribution.cdf
-                    ).statistic
+                    var: stats.kstest(effect_modifier_configs[var], var.distribution.cdf).statistic
                     for var in effect_modifier_configs.columns
                 }
             )
-            if target_ks_score and all(
-                (stat <= target_ks_score for stat in ks_stats.values())
-            ):
+            if target_ks_score and all((stat <= target_ks_score for stat in ks_stats.values())):
                 break

-        if target_ks_score is not None and not all(
-            (stat <= target_ks_score for stat in ks_stats.values())
-        ):
+        if target_ks_score is not None and not all((stat <= target_ks_score for stat in ks_stats.values())):
             logger.error(
                 "Hard max of %s reached but could not achieve target ks_score of %s. Got %s.",
                 hard_max,
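Two techniques in this file are worth unpacking. _generate_concrete_tests feeds the Latin Hypercube samples to Z3 as soft constraints, so the hard scenario and intervention constraints always win and the solver falls back to a nearby feasible value when a sample violates them; generate_concrete_tests then keeps resampling (up to hard_max) until the Kolmogorov-Smirnov statistic of the accepted values against each variable's target distribution drops below target_ks_score. A minimal sketch of both, with a hypothetical constraint and distribution:

import z3
from scipy import stats

# Soft constraints: prefer the sampled value, but never break hard constraints.
x = z3.Int("x")
optimizer = z3.Optimize()
optimizer.add(x < 10)        # hard constraint: always enforced
optimizer.add_soft(x == 42)  # sampled value: kept only if feasible
assert optimizer.check() == z3.sat
print(optimizer.model()[x])  # some value with x < 10, not 42

# KS statistic: 0 means a perfect match to the target CDF; the generation loop
# retries until every variable's statistic is <= target_ks_score.
samples = stats.norm.rvs(size=30, random_state=0)
print(stats.kstest(samples, stats.norm.cdf).statistic)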

causal_testing/json_front/json_class.py

Lines changed: 20 additions & 55 deletions

@@ -74,27 +74,19 @@ def set_variables(self, inputs: dict, outputs: dict, metas: dict):
         """
         self.inputs = [Input(i["name"], i["type"], i["distribution"]) for i in inputs]
         self.outputs = [Output(i["name"], i["type"]) for i in outputs]
-        self.metas = (
-            [Meta(i["name"], i["type"], i["populate"]) for i in metas]
-            if metas
-            else []
-        )
+        self.metas = [Meta(i["name"], i["type"], i["populate"]) for i in metas] if metas else []

     def setup(self):
         """Function to populate all the necessary parts of the json_class needed to execute tests"""
-        self.modelling_scenario = Scenario(
-            self.inputs + self.outputs + self.metas, None
-        )
+        self.modelling_scenario = Scenario(self.inputs + self.outputs + self.metas, None)
         self.modelling_scenario.setup_treatment_variables()
         self.causal_specification = CausalSpecification(
             scenario=self.modelling_scenario, causal_dag=CausalDAG(self.dag_path)
         )
         self._json_parse()
         self._populate_metas()

-    def execute_tests(
-        self, effects: dict, mutates: dict, estimators: dict, f_flag: bool
-    ):
+    def execute_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag: bool):
         """Runs and evaluates each test case specified in the JSON input

         :param effects: Dictionary mapping effect class instances to string representations.
@@ -110,20 +102,13 @@ def execute_tests(

             abstract_test = AbstractCausalTestCase(
                 scenario=self.modelling_scenario,
-                intervention_constraints=[
-                    mutates[v](k) for k, v in test["mutations"].items()
-                ],
-                treatment_variables={
-                    self.modelling_scenario.variables[v] for v in test["mutations"]
-                },
+                intervention_constraints=[mutates[v](k) for k, v in test["mutations"].items()],
+                treatment_variables={self.modelling_scenario.variables[v] for v in test["mutations"]},
                 expected_causal_effect={
                     self.modelling_scenario.variables[variable]: effects[effect]
                     for variable, effect in test["expectedEffect"].items()
                 },
-                effect_modifiers={
-                    self.modelling_scenario.variables[v]
-                    for v in test["effect_modifiers"]
-                }
+                effect_modifiers={self.modelling_scenario.variables[v] for v in test["effect_modifiers"]}
                 if "effect_modifiers" in test
                 else {},
                 estimate_type=test["estimate_type"],
@@ -132,17 +117,11 @@ def execute_tests(
             concrete_tests, dummy = abstract_test.generate_concrete_tests(5, 0.05)
             logger.info("Executing test: %s", test["name"])
             logger.info(abstract_test)
-            logger.info(
-                [(v.name, v.distribution) for v in abstract_test.treatment_variables]
-            )
-            logger.info(
-                "Number of concrete tests for test case: %s", str(len(concrete_tests))
-            )
+            logger.info([(v.name, v.distribution) for v in abstract_test.treatment_variables])
+            logger.info("Number of concrete tests for test case: %s", str(len(concrete_tests)))
             for concrete_test in concrete_tests:
                 executed_tests += 1
-                failed = self._execute_test_case(
-                    concrete_test, estimators[test["estimator"]], f_flag
-                )
+                failed = self._execute_test_case(concrete_test, estimators[test["estimator"]], f_flag)
                 if failed:
                     failures += 1

@@ -170,9 +149,7 @@ def _populate_metas(self):
             var.distribution = getattr(scipy.stats, dist)(**params)
             logger.info(var.name + f"{dist}({params})")

-    def _execute_test_case(
-        self, causal_test_case: CausalTestCase, estimator: Estimator, f_flag: bool
-    ) -> bool:
+    def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estimator, f_flag: bool) -> bool:
         """Executes a singular test case, prints the results and returns the test case result
         :param causal_test_case: The concrete test case to be executed
         :param f_flag: Failure flag that if True the script will stop executing when a test fails.
@@ -181,9 +158,7 @@ def _execute_test_case(
         """
         failed = False

-        causal_test_engine, estimation_model = self._setup_test(
-            causal_test_case, estimator
-        )
+        causal_test_engine, estimation_model = self._setup_test(causal_test_case, estimator)
         causal_test_result = causal_test_engine.execute_test(
             estimation_model, estimate_type=causal_test_case.estimate_type
         )
@@ -192,7 +167,9 @@ def _execute_test_case(

         result_string = str()
         if causal_test_result.ci_low() and causal_test_result.ci_high():
-            result_string = f"{causal_test_result.ci_low()} < {causal_test_result.ate} < {causal_test_result.ci_high()}"
+            result_string = (
+                f"{causal_test_result.ci_low()} < {causal_test_result.ate} < {causal_test_result.ci_high()}"
+            )
         else:
             result_string = causal_test_result.ate
         if f_flag:
@@ -209,34 +186,22 @@ def _execute_test_case(
             )
         return failed

-    def _setup_test(
-        self, causal_test_case: CausalTestCase, estimator: Estimator
-    ) -> tuple[CausalTestEngine, Estimator]:
+    def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) -> tuple[CausalTestEngine, Estimator]:
         """Create the necessary inputs for a single test case
         :param causal_test_case: The concrete test case to be executed
         :returns:
                 - causal_test_engine - Test Engine instance for the test being run
                 - estimation_model - Estimator instance for the test being run
         """
-        data_collector = ObservationalDataCollector(
-            self.modelling_scenario, self.data_path
-        )
-        causal_test_engine = CausalTestEngine(
-            causal_test_case, self.causal_specification, data_collector
-        )
+        data_collector = ObservationalDataCollector(self.modelling_scenario, self.data_path)
+        causal_test_engine = CausalTestEngine(causal_test_case, self.causal_specification, data_collector)
         minimal_adjustment_set = causal_test_engine.load_data(index_col=0)
         treatment_vars = list(causal_test_case.treatment_input_configuration)
-        minimal_adjustment_set = minimal_adjustment_set - {
-            v.name for v in treatment_vars
-        }
+        minimal_adjustment_set = minimal_adjustment_set - {v.name for v in treatment_vars}
         estimation_model = estimator(
             (list(treatment_vars)[0].name,),
-            [causal_test_case.treatment_input_configuration[v] for v in treatment_vars][
-                0
-            ],
-            [causal_test_case.control_input_configuration[v] for v in treatment_vars][
-                0
-            ],
+            [causal_test_case.treatment_input_configuration[v] for v in treatment_vars][0],
+            [causal_test_case.control_input_configuration[v] for v in treatment_vars][0],
             minimal_adjustment_set,
             (list(causal_test_case.outcome_variables)[0].name,),
             causal_test_engine.scenario_execution_data_df,
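For reference, execute_tests iterates over test entries parsed from the JSON input; the field names it reads are visible in the diff above. A hypothetical entry (all values invented for illustration, shown here as the equivalent Python dict) might look like:

test = {
    "name": "effect of increasing width on num_shapes",  # logged before execution
    "mutations": {"width": "Increase"},                   # variable name -> key into `mutates`
    "expectedEffect": {"num_shapes": "Positive"},         # outcome name -> key into `effects`
    "effect_modifiers": [],                               # optional; defaults to {} when absent
    "estimate_type": "ate",
    "estimator": "LinearRegressionEstimator",             # key into `estimators`
}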
