
Commit 8fafe9d

Merge pull request #177 from CITCOM-project/json-cate
Json cate
2 parents 72c3997 + 1203490 · commit 8fafe9d

25 files changed: +448 −174 lines

causal_testing/data_collection/data_collector.py

Lines changed: 3 additions & 2 deletions
@@ -61,7 +61,7 @@ def filter_valid_data(self, data: pd.DataFrame, check_pos: bool = True) -> pd.Da
             self.scenario.variables[var].z3
             == self.scenario.variables[var].z3_val(self.scenario.variables[var].z3, row[var])
             for var in self.scenario.variables
-            if var in row
+            if var in row and not pd.isnull(row[var])
         ]
         for c in model:
             solver.assert_and_track(c, f"model: {c}")
@@ -147,7 +147,8 @@ def collect_data(self, **kwargs) -> pd.DataFrame:

         execution_data_df = self.data
         for meta in self.scenario.metas():
-            meta.populate(execution_data_df)
+            if meta.name not in self.data:
+                meta.populate(execution_data_df)
         scenario_execution_data_df = self.filter_valid_data(execution_data_df)
         for var_name, var in self.scenario.variables.items():
             if issubclass(var.datatype, Enum):
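A note on the first hunk: the extra pd.isnull check means a runtime row with a missing value no longer produces a Z3 constraint pinning that variable to NaN. A minimal standalone sketch of the guard (the column names and row are invented for illustration, not taken from the commit):

    import pandas as pd

    row = pd.Series({"width": 1.5, "height": None})  # "height" was logged but is missing
    for var in ["width", "height", "depth"]:
        # the old check `var in row` alone would still constrain "height" to NaN
        if var in row and not pd.isnull(row[var]):
            print(f"would assert Z3 constraint {var} == {row[var]}")  # only "width" qualifies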

causal_testing/generation/abstract_causal_test_case.py

Lines changed: 1 addition & 1 deletion
@@ -131,7 +131,7 @@ def _generate_concrete_tests(
             )

             for v in self.scenario.inputs():
-                if row[v.name] != v.cast(model[v.z3]):
+                if v.name in row and row[v.name] != v.cast(model[v.z3]):
                     constraints = "\n ".join([str(c) for c in self.scenario.constraints if v.name in str(c)])
                     logger.warning(
                         f"Unable to set variable {v.name} to {row[v.name]} because of constraints\n"

causal_testing/generation/enum_gen.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+"""This module contains the class EnumGen, which allows us to easily create
+generating uniform distributions from enums."""
+
+from enum import Enum
+from scipy.stats import rv_discrete
+import numpy as np
+
+
+class EnumGen(rv_discrete):
+    """This class allows us to easily create generating uniform distributions
+    from enums. This is helpful for generating concrete test inputs from
+    abstract test cases."""
+
+    def __init__(self, datatype: Enum):
+        super().__init__()
+        self.datatype = dict(enumerate(datatype, 1))
+        self.inverse_dt = {v: k for k, v in self.datatype.items()}
+
+    def ppf(self, q):
+        """Percent point function (inverse of `cdf`) at q of the given RV.
+        Parameters
+        ----------
+        q : array_like
+            Lower tail probability.
+        Returns
+        -------
+        k : array_like
+            Quantile corresponding to the lower tail probability, q.
+        """
+        return np.vectorize(self.datatype.get)(np.ceil(len(self.datatype) * q))
+
+    def cdf(self, k):
+        """
+        Cumulative distribution function of the given RV.
+        Parameters
+        ----------
+        k : array_like
+            quantiles
+        Returns
+        -------
+        cdf : ndarray
+            Cumulative distribution function evaluated at `x`
+        """
+        return np.vectorize(self.inverse_dt.get)(k) / len(self.datatype)
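A hedged usage sketch for the new EnumGen class: it maps the members of an Enum onto the integers 1..N so that uniform draws in (0, 1] can be turned into enum values via ppf. The Color enum below is invented for illustration.

    import numpy as np
    from enum import Enum
    from causal_testing.generation.enum_gen import EnumGen

    class Color(Enum):
        RED = 1
        GREEN = 2
        BLUE = 3

    gen = EnumGen(Color)
    draws = np.random.uniform(low=0.01, high=1.0, size=5)
    print(gen.ppf(draws))       # e.g. [Color.GREEN Color.RED Color.BLUE ...]
    print(gen.cdf(Color.BLUE))  # 1.0, since BLUE is the last of three members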

causal_testing/json_front/json_class.py

Lines changed: 86 additions & 41 deletions
@@ -20,10 +20,10 @@
 from causal_testing.specification.causal_specification import CausalSpecification
 from causal_testing.specification.scenario import Scenario
 from causal_testing.specification.variable import Input, Meta, Output
-from causal_testing.testing.base_test_case import BaseTestCase
 from causal_testing.testing.causal_test_case import CausalTestCase
 from causal_testing.testing.causal_test_engine import CausalTestEngine
 from causal_testing.testing.estimators import Estimator
+from causal_testing.testing.base_test_case import BaseTestCase

 logger = logging.getLogger(__name__)

@@ -41,7 +41,7 @@ class JsonUtility:
     :attr {Meta} metas: Causal variables representing metavariables.
     :attr {pd.DataFrame}: Pandas DataFrame containing runtime data.
     :attr {dict} test_plan: Dictionary containing the key value pairs from the loaded json test plan.
-    :attr {Scenario} modelling_scenario:
+    :attr {Scenario} scenario:
     :attr {CausalSpecification} causal_specification:
     """

@@ -75,6 +75,32 @@ def setup(self, scenario: Scenario):
         self._json_parse()
         self._populate_metas()

+    def _create_abstract_test_case(self, test, mutates, effects):
+        assert len(test["mutations"]) == 1
+        treatment_var = next(self.scenario.variables[v] for v in test["mutations"])
+        if not treatment_var.distribution:
+            fitter = Fitter(self.data[treatment_var.name], distributions=get_common_distributions())
+            fitter.fit()
+            (dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
+            treatment_var.distribution = getattr(scipy.stats, dist)(**params)
+            self._append_to_file(treatment_var.name + f" {dist}({params})", logging.INFO)
+
+        abstract_test = AbstractCausalTestCase(
+            scenario=self.scenario,
+            intervention_constraints=[mutates[v](k) for k, v in test["mutations"].items()],
+            treatment_variable=treatment_var,
+            expected_causal_effect={
+                self.scenario.variables[variable]: effects[effect]
+                for variable, effect in test["expected_effect"].items()
+            },
+            effect_modifiers={self.scenario.variables[v] for v in test["effect_modifiers"]}
+            if "effect_modifiers" in test
+            else {},
+            estimate_type=test["estimate_type"],
+            effect=test.get("effect", "total"),
+        )
+        return abstract_test
+
     def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False, mutates: dict = None):
         """Runs and evaluates each test case specified in the JSON input

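The Fitter fallback in _create_abstract_test_case can be exercised on its own. A standalone sketch, under the assumption that the fitter package's get_best returns a {distribution_name: params_dict} mapping (which is what the commit's own call relies on); the sample data are invented:

    import numpy as np
    import scipy.stats
    from fitter import Fitter, get_common_distributions

    data = np.random.gamma(shape=2.0, scale=3.0, size=500)  # stand-in for self.data[treatment_var.name]
    fitter = Fitter(data, distributions=get_common_distributions())
    fitter.fit()
    dist, params = list(fitter.get_best(method="sumsquare_error").items())[0]
    distribution = getattr(scipy.stats, dist)(**params)  # frozen scipy.stats distribution
    print(dist, distribution.mean())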
@@ -84,23 +110,51 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
         :param f_flag: Failure flag that if True the script will stop executing when a test fails.
         """
         failures = 0
+        msg = ""
         for test in self.test_plan["tests"]:
             if "skip" in test and test["skip"]:
                 continue
             test["estimator"] = estimators[test["estimator"]]
             if "mutations" in test:
-                abstract_test = self._create_abstract_test_case(test, mutates, effects)
-
-                concrete_tests, dummy = abstract_test.generate_concrete_tests(5, 0.05)
-                failures = self._execute_tests(concrete_tests, test, f_flag)
-                msg = (
-                    f"Executing test: {test['name']}\n"
-                    + "abstract_test\n"
-                    + f"{abstract_test}\n"
-                    + f"{abstract_test.treatment_variable.name},{abstract_test.treatment_variable.distribution}\n"
-                    + f"Number of concrete tests for test case: {str(len(concrete_tests))}\n"
-                    + f"{failures}/{len(concrete_tests)} failed for {test['name']}"
-                )
+                if test["estimate_type"] == "coefficient":
+                    base_test_case = BaseTestCase(
+                        treatment_variable=next(self.scenario.variables[v] for v in test["mutations"]),
+                        outcome_variable=next(self.scenario.variables[v] for v in test["expected_effect"]),
+                        effect=test.get("effect", "direct"),
+                    )
+                    assert len(test["expected_effect"]) == 1, "Can only have one expected effect."
+                    concrete_tests = [
+                        CausalTestCase(
+                            base_test_case=base_test_case,
+                            expected_causal_effect=next(
+                                effects[effect] for variable, effect in test["expected_effect"].items()
+                            ),
+                            estimate_type="coefficient",
+                            effect_modifier_configuration={
+                                self.scenario.variables[v] for v in test.get("effect_modifiers", [])
+                            },
+                        )
+                    ]
+                    failures = self._execute_tests(concrete_tests, test, f_flag)
+                    msg = (
+                        f"Executing test: {test['name']} \n"
+                        + f" {concrete_tests[0]} \n"
+                        + f" {failures}/{len(concrete_tests)} failed for {test['name']}"
+                    )
+                else:
+                    abstract_test = self._create_abstract_test_case(test, mutates, effects)
+                    concrete_tests, dummy = abstract_test.generate_concrete_tests(5, 0.05)
+                    failures = self._execute_tests(concrete_tests, test, f_flag)
+
+                    msg = (
+                        f"Executing test: {test['name']} \n"
+                        + " abstract_test \n"
+                        + f" {abstract_test} \n"
+                        + f" {abstract_test.treatment_variable.name},"
+                        + f" {abstract_test.treatment_variable.distribution} \n"
+                        + f" Number of concrete tests for test case: {str(len(concrete_tests))} \n"
+                        + f" {failures}/{len(concrete_tests)} failed for {test['name']}"
+                    )
                 self._append_to_file(msg, logging.INFO)
             else:
                 outcome_variable = next(
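For reference, a test-plan entry that takes the new coefficient branch would look roughly like the stub below. The keys mirror those read by run_json_tests and emitted by the to_json_stub methods added later in this commit; the variable names are invented.

    import json

    test = {
        "name": "width --> height",                 # hypothetical test name
        "estimator": "LinearRegressionEstimator",
        "estimate_type": "coefficient",
        "effect": "direct",
        "mutations": ["width"],
        "expected_effect": {"height": "SomeEffect"},
        "skip": False,
    }
    print(json.dumps({"tests": [test]}, indent=2))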
@@ -132,24 +186,6 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
                 )
                 self._append_to_file(msg, logging.INFO)

-    def _create_abstract_test_case(self, test, mutates, effects):
-        assert len(test["mutations"]) == 1
-        abstract_test = AbstractCausalTestCase(
-            scenario=self.scenario,
-            intervention_constraints=[mutates[v](k) for k, v in test["mutations"].items()],
-            treatment_variable=next(self.scenario.variables[v] for v in test["mutations"]),
-            expected_causal_effect={
-                self.scenario.variables[variable]: effects[effect]
-                for variable, effect in test["expected_effect"].items()
-            },
-            effect_modifiers={self.scenario.variables[v] for v in test["effect_modifiers"]}
-            if "effect_modifiers" in test
-            else {},
-            estimate_type=test["estimate_type"],
-            effect=test.get("effect", "total"),
-        )
-        return abstract_test
-
     def _execute_tests(self, concrete_tests, test, f_flag):
         failures = 0
         if "formula" in test:
@@ -175,13 +211,6 @@ def _populate_metas(self):
         """
         for meta in self.scenario.variables_of_type(Meta):
             meta.populate(self.data)
-        for var in self.scenario.variables_of_type(Meta).union(self.scenario.variables_of_type(Output)):
-            if not var.distribution:
-                fitter = Fitter(self.data[var.name], distributions=get_common_distributions())
-                fitter.fit()
-                (dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
-                var.distribution = getattr(scipy.stats, dist)(**params)
-                self._append_to_file(var.name + f" {dist}({params})", logging.INFO)

     def _execute_test_case(self, causal_test_case: CausalTestCase, test: Iterable[Mapping], f_flag: bool) -> bool:
         """Executes a singular test case, prints the results and returns the test case result
@@ -193,6 +222,15 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, test: Iterable[Ma
         :rtype: bool
         """
         failed = False
+
+        for var in self.scenario.variables_of_type(Meta).union(self.scenario.variables_of_type(Output)):
+            if not var.distribution:
+                fitter = Fitter(self.data[var.name], distributions=get_common_distributions())
+                fitter.fit()
+                (dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
+                var.distribution = getattr(scipy.stats, dist)(**params)
+                self._append_to_file(var.name + f" {dist}({params})", logging.INFO)
+
         causal_test_engine, estimation_model = self._setup_test(causal_test_case, test)
         causal_test_result = causal_test_engine.execute_test(
             estimation_model, causal_test_case, estimate_type=causal_test_case.estimate_type
@@ -218,16 +256,23 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, test: Iterable[Ma
             logger.warning(" FAILED- expected %s, got %s", causal_test_case.expected_causal_effect, result_string)
         return failed

-    def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> tuple[CausalTestEngine, Estimator]:
+    def _setup_test(
+        self, causal_test_case: CausalTestCase, test: Mapping, conditions: list[str] = None
+    ) -> tuple[CausalTestEngine, Estimator]:
         """Create the necessary inputs for a single test case
         :param causal_test_case: The concrete test case to be executed
         :param test: Single JSON test definition stored in a mapping (dict)
+        :param conditions: A list of conditions which should be applied to the
+            data. Conditions should be in the query format detailed at
+            https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.query.html
         :returns:
            - causal_test_engine - Test Engine instance for the test being run
            - estimation_model - Estimator instance for the test being run
         """

-        data_collector = ObservationalDataCollector(self.scenario, self.data)
+        data_collector = ObservationalDataCollector(
+            self.scenario, self.data.query(" & ".join(conditions)) if conditions else self.data
+        )
         causal_test_engine = CausalTestEngine(self.causal_specification, data_collector, index_col=0)

         minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)
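The new conditions parameter simply pre-filters the runtime data with pandas.DataFrame.query before it reaches the ObservationalDataCollector. A pandas-only sketch of that expression (columns invented):

    import pandas as pd

    data = pd.DataFrame({"width": [1, 2, 3, 4], "height": [10, 20, 30, 40]})
    conditions = ["width > 1", "height < 40"]  # pandas query strings, as documented above
    filtered = data.query(" & ".join(conditions)) if conditions else data
    print(filtered)  # keeps only the rows with width 2 and 3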

causal_testing/specification/causal_dag.py

Lines changed: 9 additions & 0 deletions
@@ -521,5 +521,14 @@ def identification(self, base_test_case: BaseTestCase):
         minimal_adjustment_set = min(minimal_adjustment_sets, key=len)
         return minimal_adjustment_set

+    def to_dot_string(self) -> str:
+        """Return a string of the DOT representation of the causal DAG.
+        :return DOT string of the DAG.
+        """
+        dotstring = "digraph G {\n"
+        dotstring += "".join([f"{a} -> {b};\n" for a, b in self.graph.edges])
+        dotstring += "}"
+        return dotstring
+
     def __str__(self):
         return f"Nodes: {self.graph.nodes}\nEdges: {self.graph.edges}"

causal_testing/specification/metamorphic_relation.py

Lines changed: 28 additions & 0 deletions
@@ -102,6 +102,10 @@ def execute_tests(self, data_collector: ExperimentalDataCollector):
     def assertion(self, source_output, follow_up_output):
         """An assertion that should be applied to an individual metamorphic test run."""

+    @abstractmethod
+    def to_json_stub(self, skip=True) -> dict:
+        """Convert to a JSON frontend stub string for user customisation"""
+
     @abstractmethod
     def test_oracle(self, test_results):
         """A test oracle that assert whether the MR holds or not based on ALL test results.
@@ -129,6 +133,18 @@ def test_oracle(self, test_results):
             self.tests
         ), f"{str(self)}: {len(test_results['fail'])}/{len(self.tests)} tests failed."

+    def to_json_stub(self, skip=True) -> dict:
+        """Convert to a JSON frontend stub string for user customisation"""
+        return {
+            "name": str(self),
+            "estimator": "LinearRegressionEstimator",
+            "estimate_type": "coefficient",
+            "effect": "direct",
+            "mutations": [self.treatment_var],
+            "expected_effect": {self.output_var: "SomeEffect"},
+            "skip": skip,
+        }
+
     def __str__(self):
         formatted_str = f"{self.treatment_var} --> {self.output_var}"
         if self.adjustment_vars:
@@ -149,6 +165,18 @@ def test_oracle(self, test_results):
             len(test_results["fail"]) == 0
         ), f"{str(self)}: {len(test_results['fail'])}/{len(self.tests)} tests failed."

+    def to_json_stub(self, skip=True) -> dict:
+        """Convert to a JSON frontend stub string for user customisation"""
+        return {
+            "name": str(self),
+            "estimator": "LinearRegressionEstimator",
+            "estimate_type": "coefficient",
+            "effect": "direct",
+            "mutations": [self.treatment_var],
+            "expected_effect": {self.output_var: "NoEffect"},
+            "skip": skip,
+        }
+
     def __str__(self):
         formatted_str = f"{self.treatment_var} _||_ {self.output_var}"
         if self.adjustment_vars:
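The to_json_stub methods make it possible to bootstrap a JSON front-end test plan from generated metamorphic relations. A hedged sketch, assuming relations is an existing list of relation objects of the two concrete classes shown above:

    import json

    stubs = [relation.to_json_stub(skip=True) for relation in relations]
    with open("causal_tests.json", "w") as f:  # hypothetical output file
        json.dump({"tests": stubs}, f, indent=2)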

causal_testing/testing/causal_test_case.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ def __init__(
         self,
         base_test_case: BaseTestCase,
         expected_causal_effect: CausalTestOutcome,
-        control_value: Any,
+        control_value: Any = None,
         treatment_value: Any = None,
         estimate_type: str = "ate",
         effect_modifier_configuration: dict[Variable:Any] = None,
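With control_value now defaulting to None, a coefficient-style test case can be constructed without concrete control and treatment values, exactly as the JSON front end does above. A minimal sketch (base_test_case and expected_effect are assumed to already exist):

    causal_test_case = CausalTestCase(
        base_test_case=base_test_case,
        expected_causal_effect=expected_effect,  # some CausalTestOutcome instance
        estimate_type="coefficient",
    )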

causal_testing/testing/causal_test_engine.py

Lines changed: 9 additions & 0 deletions
@@ -174,6 +174,15 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
                 effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
                 confidence_intervals=confidence_intervals,
             )
+        elif estimate_type == "coefficient":
+            logger.debug("calculating coefficient")
+            coefficient, confidence_intervals = estimator.estimate_unit_ate()
+            causal_test_result = CausalTestResult(
+                estimator=estimator,
+                test_value=TestValue("coefficient", coefficient),
+                effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
+                confidence_intervals=confidence_intervals,
+            )
         elif estimate_type == "ate":
             logger.debug("calculating ate")
             ate, confidence_intervals = estimator.estimate_ate()
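The new branch reports the estimator's unit ATE as a regression coefficient wrapped in a TestValue. A hedged sketch of driving it through the engine, with the engine, estimator and test case set up elsewhere:

    causal_test_result = causal_test_engine.execute_test(
        estimation_model, causal_test_case, estimate_type="coefficient"
    )
    print(causal_test_result)  # carries TestValue("coefficient", <estimate>) plus confidence intervals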

0 commit comments