
Commit bf0bbe3

Merge pull request #215 from CITCOM-project/test-adequacy
Test adequacy
2 parents 79dcb48 + 933a168 commit bf0bbe3

18 files changed: +425 −119 lines

causal_testing/data_collection/data_collector.py

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@ def filter_valid_data(self, data: pd.DataFrame, check_pos: bool = True) -> pd.DataFrame:
             solver.push()
             # Check that the row does not violate any scenario constraints
             # Need to explicitly cast variables to their specified type. Z3 will not take e.g. np.int64 to be an int.
+            # Check that the row does not violate any scenario constraints
             model = [
                 self.scenario.variables[var].z3
                 == self.scenario.variables[var].z3_val(self.scenario.variables[var].z3, row[var])
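
The comment added here concerns the Z3 Python bindings only accepting native Python numeric types. A minimal standalone sketch of the kind of explicit cast the comment refers to, using plain z3 calls rather than the project's z3_val helper (whose exact behaviour is assumed from the diff):

import numpy as np
import z3

x = z3.Int("x")
solver = z3.Solver()
value = np.int64(5)  # the kind of value read from a pandas DataFrame row

# Casting to a native int before building the constraint avoids Z3
# rejecting the numpy scalar type.
solver.add(x == z3.IntVal(int(value)))
print(solver.check())  # sat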

causal_testing/json_front/json_class.py

Lines changed: 55 additions & 46 deletions
@@ -24,6 +24,7 @@
 from causal_testing.testing.causal_test_result import CausalTestResult
 from causal_testing.testing.estimators import Estimator
 from causal_testing.testing.base_test_case import BaseTestCase
+from causal_testing.testing.causal_test_adequacy import DataAdequacy

 logger = logging.getLogger(__name__)

@@ -66,9 +67,8 @@ def set_paths(self, json_path: str, dag_path: str, data_paths: list[str] = None):
             data_paths = []
         self.input_paths = JsonClassPaths(json_path=json_path, dag_path=dag_path, data_paths=data_paths)

-    def setup(self, scenario: Scenario):
+    def setup(self, scenario: Scenario, data=None):
         """Function to populate all the necessary parts of the json_class needed to execute tests"""
-        data = []
         self.scenario = scenario
         self._get_scenario_variables()
         self.scenario.setup_treatment_variables()
@@ -81,9 +81,9 @@ def setup(self, scenario: Scenario):
         # Populate the data
         if self.input_paths.data_paths:
             data = pd.concat([pd.read_csv(data_file, header=0) for data_file in self.input_paths.data_paths])
-        if len(data) == 0:
+        if data is None or len(data) == 0:
             raise ValueError(
-                "No data found, either provide a path to a file containing data or manually populate the .data "
+                "No data found. Please either provide a path to a file containing data or manually populate the .data "
                 "attribute with a dataframe before calling .setup()"
             )
         self.data_collector = ObservationalDataCollector(self.scenario, data)
@@ -128,40 +128,20 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
             if "skip" in test and test["skip"]:
                 continue
             test["estimator"] = estimators[test["estimator"]]
-            if "mutations" in test:
+            # If we have specified concrete control and treatment value
+            if "mutations" not in test:
+                failed, msg = self._run_concrete_metamorphic_test(test, f_flag, effects)
+            # If we have a variable to mutate
+            else:
                 if test["estimate_type"] == "coefficient":
-                    msg = self._run_coefficient_test(test=test, f_flag=f_flag, effects=effects)
+                    failed, msg = self._run_coefficient_test(test=test, f_flag=f_flag, effects=effects)
                 else:
-                    msg = self._run_ate_test(test=test, f_flag=f_flag, effects=effects, mutates=mutates)
-                self._append_to_file(msg, logging.INFO)
-            else:
-                outcome_variable = next(
-                    iter(test["expected_effect"])
-                )  # Take first key from dictionary of expected effect
-                base_test_case = BaseTestCase(
-                    treatment_variable=self.variables["inputs"][test["treatment_variable"]],
-                    outcome_variable=self.variables["outputs"][outcome_variable],
-                )
-
-                causal_test_case = CausalTestCase(
-                    base_test_case=base_test_case,
-                    expected_causal_effect=effects[test["expected_effect"][outcome_variable]],
-                    control_value=test["control_value"],
-                    treatment_value=test["treatment_value"],
-                    estimate_type=test["estimate_type"],
-                )
-
-                failed, _ = self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag)
-
-                msg = (
-                    f"Executing concrete test: {test['name']} \n"
-                    + f"treatment variable: {test['treatment_variable']} \n"
-                    + f"outcome_variable = {outcome_variable} \n"
-                    + f"control value = {test['control_value']}, treatment value = {test['treatment_value']} \n"
-                    + f"Result: {'FAILED' if failed else 'Passed'}"
-                )
-                print(msg)
-                self._append_to_file(msg, logging.INFO)
+                    failed, msg = self._run_metamorphic_tests(
+                        test=test, f_flag=f_flag, effects=effects, mutates=mutates
+                    )
+            test["failed"] = failed
+            test["result"] = msg
+        return self.test_plan["tests"]

     def _run_coefficient_test(self, test: dict, f_flag: bool, effects: dict):
         """Builds structures and runs test case for tests with an estimate_type of 'coefficient'.
@@ -183,18 +163,45 @@ def _run_coefficient_test(self, test: dict, f_flag: bool, effects: dict):
             estimate_type="coefficient",
             effect_modifier_configuration={self.scenario.variables[v] for v in test.get("effect_modifiers", [])},
         )
-        result = self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag)
+        failed, result = self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag)
         msg = (
             f"Executing test: {test['name']} \n"
             + f" {causal_test_case} \n"
             + " "
-            + ("\n ").join(str(result[1]).split("\n"))
+            + ("\n ").join(str(result).split("\n"))
             + "==============\n"
-            + f" Result: {'FAILED' if result[0] else 'Passed'}"
+            + f" Result: {'FAILED' if failed else 'Passed'}"
+        )
+        self._append_to_file(msg, logging.INFO)
+        return failed, result
+
+    def _run_concrete_metamorphic_test(self, test: dict, f_flag: bool, effects: dict):
+        outcome_variable = next(iter(test["expected_effect"]))  # Take first key from dictionary of expected effect
+        base_test_case = BaseTestCase(
+            treatment_variable=self.variables["inputs"][test["treatment_variable"]],
+            outcome_variable=self.variables["outputs"][outcome_variable],
         )
-        return msg

-    def _run_ate_test(self, test: dict, f_flag: bool, effects: dict, mutates: dict):
+        causal_test_case = CausalTestCase(
+            base_test_case=base_test_case,
+            expected_causal_effect=effects[test["expected_effect"][outcome_variable]],
+            control_value=test["control_value"],
+            treatment_value=test["treatment_value"],
+            estimate_type=test["estimate_type"],
+        )
+        failed, msg = self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag)
+
+        msg = (
+            f"Executing concrete test: {test['name']} \n"
+            + f"treatment variable: {test['treatment_variable']} \n"
+            + f"outcome_variable = {outcome_variable} \n"
+            + f"control value = {test['control_value']}, treatment value = {test['treatment_value']} \n"
+            + f"Result: {'FAILED' if failed else 'Passed'}"
+        )
+        self._append_to_file(msg, logging.INFO)
+        return failed, msg
+
+    def _run_metamorphic_tests(self, test: dict, f_flag: bool, effects: dict, mutates: dict):
         """Builds structures and runs test case for tests with an estimate_type of 'ate'.

         :param test: Single JSON test definition stored in a mapping (dict)
@@ -226,7 +233,8 @@ def _run_ate_test(self, test: dict, f_flag: bool, effects: dict, mutates: dict):
             + f" Number of concrete tests for test case: {str(len(concrete_tests))} \n"
             + f" {failures}/{len(concrete_tests)} failed for {test['name']}"
         )
-        return msg
+        self._append_to_file(msg, logging.INFO)
+        return failures, msg

     def _execute_tests(self, concrete_tests, test, f_flag):
         failures = 0
@@ -265,9 +273,13 @@ def _execute_test_case(
         causal_test_result = causal_test_case.execute_test(
             estimator=estimation_model, data_collector=self.data_collector
         )
-
         test_passes = causal_test_case.expected_causal_effect.apply(causal_test_result)

+        if "coverage" in test and test["coverage"]:
+            adequacy_metric = DataAdequacy(causal_test_case, estimation_model, self.data_collector)
+            adequacy_metric.measure_adequacy()
+            causal_test_result.adequacy = adequacy_metric
+
         if causal_test_result.ci_low() is not None and causal_test_result.ci_high() is not None:
             result_string = (
                 f"{causal_test_result.ci_low()} < {causal_test_result.test_value.value} < "
@@ -283,7 +295,6 @@ def _execute_test_case(
                 f"got {result_string}"
             )
             failed = True
-            logger.warning(" FAILED- expected %s, got %s", causal_test_case.expected_causal_effect, result_string)
         return failed, causal_test_result

     def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> Estimator:
@@ -294,7 +305,6 @@ def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> Estimator:
         data. Conditions should be in the query format detailed at
         https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.query.html
         :returns:
-        - causal_test_engine - Test Engine instance for the test being run
         - estimation_model - Estimator instance for the test being run
         """
         minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)
@@ -370,7 +380,6 @@ def get_args(test_args=None) -> argparse.Namespace:
     parser.add_argument(
         "--log_path",
         help="Specify a directory to change the location of the log file",
-        default="./json_frontend.log",
     )
     parser.add_argument(
         "--data_path",
causal_testing/testing/causal_test_adequacy.py

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
+"""
+This module contains code to measure various aspects of causal test adequacy.
+"""
+from itertools import combinations
+from copy import deepcopy
+import pandas as pd
+
+from causal_testing.testing.causal_test_suite import CausalTestSuite
+from causal_testing.data_collection.data_collector import DataCollector
+from causal_testing.specification.causal_dag import CausalDAG
+from causal_testing.testing.estimators import Estimator
+from causal_testing.testing.causal_test_case import CausalTestCase
+
+
+class DAGAdequacy:
+    """
+    Measures the adequacy of a given DAG by how many edges and independences are tested.
+    """
+
+    def __init__(
+        self,
+        causal_dag: CausalDAG,
+        test_suite: CausalTestSuite,
+    ):
+        self.causal_dag = causal_dag
+        self.test_suite = test_suite
+        self.tested_pairs = None
+        self.pairs_to_test = None
+        self.untested_edges = None
+        self.dag_adequacy = None
+
+    def measure_adequacy(self):
+        """
+        Calculate the adequacy measurement, and populate the `dag_adequacy` field.
+        """
+        self.tested_pairs = {(t.treatment_variable, t.outcome_variable) for t in self.test_suite}
+        self.pairs_to_test = set(combinations(self.causal_dag.graph.nodes, 2))
+        self.untested_edges = self.pairs_to_test.difference(self.tested_pairs)
+        self.dag_adequacy = len(self.tested_pairs) / len(self.pairs_to_test)
+
+    def to_dict(self):
+        "Returns the adequacy object as a dictionary."
+        return {
+            "causal_dag": self.causal_dag,
+            "test_suite": self.test_suite,
+            "tested_pairs": self.tested_pairs,
+            "pairs_to_test": self.pairs_to_test,
+            "untested_edges": self.untested_edges,
+            "dag_adequacy": self.dag_adequacy,
+        }
+
+
+class DataAdequacy:
+    """
+    Measures the adequacy of a given test according to the Fisher kurtosis of the bootstrapped result.
+    - Positive kurtoses indicate the model doesn't have enough data, so is unstable.
+    - Negative kurtoses indicate the model doesn't have enough data, but is too stable, indicating that the spread of
+      inputs is insufficient.
+    - Zero kurtosis is optimal.
+    """
+
+    def __init__(
+        self, test_case: CausalTestCase, estimator: Estimator, data_collector: DataCollector, bootstrap_size: int = 100
+    ):
+        self.test_case = test_case
+        self.estimator = estimator
+        self.data_collector = data_collector
+        self.kurtosis = None
+        self.outcomes = None
+        self.bootstrap_size = bootstrap_size
+
+    def measure_adequacy(self):
+        """
+        Calculate the adequacy measurement, and populate the `kurtosis` and `outcomes` fields.
+        """
+        results = []
+        for i in range(self.bootstrap_size):
+            estimator = deepcopy(self.estimator)
+            estimator.df = estimator.df.sample(len(estimator.df), replace=True, random_state=i)
+            # try:
+            results.append(self.test_case.execute_test(estimator, self.data_collector))
+            # except np.LinAlgError:
+            #     continue
+        outcomes = [self.test_case.expected_causal_effect.apply(c) for c in results]
+        results = pd.DataFrame(c.to_dict() for c in results)[["effect_estimate", "ci_low", "ci_high"]]
+
+        def convert_to_df(field):
+            converted = []
+            for r in results[field]:
+                if isinstance(r, float):
+                    converted.append(
+                        pd.DataFrame({self.test_case.base_test_case.treatment_variable.name: [r]}).transpose()
+                    )
+                else:
+                    converted.append(r)
+            return converted
+
+        for field in ["effect_estimate", "ci_low", "ci_high"]:
+            results[field] = convert_to_df(field)
+
+        effect_estimate = pd.concat(results["effect_estimate"].tolist(), axis=1).transpose().reset_index(drop=True)
+        self.kurtosis = effect_estimate.kurtosis()
+        self.outcomes = sum(outcomes)
+
+    def to_dict(self):
+        "Returns the adequacy object as a dictionary."
+        return {"kurtosis": self.kurtosis.to_dict(), "bootstrap_size": self.bootstrap_size, "passing": self.outcomes}

causal_testing/testing/causal_test_case.py

Lines changed: 2 additions & 9 deletions
@@ -18,12 +18,8 @@ class CausalTestCase:
     """
     A CausalTestCase extends the information held in a BaseTestCase. As well as storing the treatment and outcome
     variables, a CausalTestCase stores the values of these variables. Also the outcome variable and value are
-    specified.
-
-    The goal of a CausalTestCase is to test whether the intervention made to the control via the treatment causes the
-    model-under-test to produce the expected change. The CausalTestCase structure is designed for execution using the
-    CausalTestEngine, using either execute_test() function to execute a single test case or packing CausalTestCases into
-    a CausalTestSuite and executing them as a batch using the execute_test_suite() function.
+    specified. The goal of a CausalTestCase is to test whether the intervention made to the control via the treatment
+    causes the model-under-test to produce the expected change.
     """

     def __init__(
@@ -87,9 +83,6 @@ def execute_test(self, estimator: type(Estimator), data_collector: DataCollector
         if estimator.df is None:
             estimator.df = data_collector.collect_data()

-        logger.info("treatments: %s", self.treatment_variable.name)
-        logger.info("outcomes: %s", self.outcome_variable)
-
         causal_test_result = self._return_causal_test_results(estimator)
         return causal_test_result

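
Logging aside, execute_test keeps its lazy-data behaviour: the estimator's dataframe is populated from the data collector only if it has not already been set. A small sketch of that guard in isolation; CollectorStub and EstimatorStub are made-up stand-ins, not framework classes:

import pandas as pd

class CollectorStub:
    """Stand-in for a DataCollector: returns a fixed dataframe."""
    def collect_data(self) -> pd.DataFrame:
        return pd.DataFrame({"dose": [1.0, 2.0], "response": [3.0, 5.0]})

class EstimatorStub:
    """Stand-in for an Estimator: just holds a dataframe."""
    def __init__(self, df: pd.DataFrame = None):
        self.df = df

def execute_test(estimator: EstimatorStub, data_collector: CollectorStub) -> pd.DataFrame:
    # Same guard as CausalTestCase.execute_test: only collect data if none was supplied.
    if estimator.df is None:
        estimator.df = data_collector.collect_data()
    return estimator.df

print(execute_test(EstimatorStub(), CollectorStub()).shape)  # (2, 2)
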

causal_testing/testing/causal_test_outcome.py

Lines changed: 13 additions & 8 deletions
@@ -41,8 +41,14 @@ def apply(self, res: CausalTestResult) -> bool:
 class NoEffect(CausalTestOutcome):
     """An extension of TestOutcome representing that the expected causal effect should be zero."""

-    def __init__(self, atol: float = 1e-10):
+    def __init__(self, atol: float = 1e-10, ctol: float = 0.05):
+        """
+        :param atol: Arithmetic tolerance. The test will pass if the absolute value of the causal effect is less than
+        atol.
+        :param ctol: Categorical tolerance. The test will pass if fewer than this proportion of categories fail.
+        """
         self.atol = atol
+        self.ctol = ctol

     def apply(self, res: CausalTestResult) -> bool:
         if res.test_value.type == "ate":
@@ -52,14 +58,13 @@ def apply(self, res: CausalTestResult) -> bool:
             ci_high = res.ci_high() if isinstance(res.ci_high(), Iterable) else [res.ci_high()]
             value = res.test_value.value if isinstance(res.ci_high(), Iterable) else [res.test_value.value]

-            if not all(ci_low < 0 < ci_high for ci_low, ci_high in zip(ci_low, ci_high)):
-                print(
-                    "FAILING ON",
-                    [(ci_low, ci_high) for ci_low, ci_high in zip(ci_low, ci_high) if not ci_low < 0 < ci_high],
+            return (
+                sum(
+                    not ((ci_low < 0 < ci_high) or abs(v) < self.atol)
+                    for ci_low, ci_high, v in zip(ci_low, ci_high, value)
                 )
-
-            return all(ci_low < 0 < ci_high for ci_low, ci_high in zip(ci_low, ci_high)) or all(
-                abs(v) < self.atol for v in value
+                / len(value)
+                < self.ctol
             )
         if res.test_value.type == "risk_ratio":
             return (res.ci_low() < 1 < res.ci_high()) or np.isclose(res.test_value.value, 1.0, atol=self.atol)