Skip to content

Commit 75a9310

Browse files
Merge branch 'main' into dev_docs
2 parents fc7ab48 + bc13d84 commit 75a9310

25 files changed

+463
-190
lines changed

causal_testing/data_collection/data_collector.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def filter_valid_data(self, data: pd.DataFrame, check_pos: bool = True) -> pd.Da
6161
self.scenario.variables[var].z3
6262
== self.scenario.variables[var].z3_val(self.scenario.variables[var].z3, row[var])
6363
for var in self.scenario.variables
64-
if var in row
64+
if var in row and not pd.isnull(row[var])
6565
]
6666
for c in model:
6767
solver.assert_and_track(c, f"model: {c}")
@@ -147,7 +147,8 @@ def collect_data(self, **kwargs) -> pd.DataFrame:
147147

148148
execution_data_df = self.data
149149
for meta in self.scenario.metas():
150-
meta.populate(execution_data_df)
150+
if meta.name not in self.data:
151+
meta.populate(execution_data_df)
151152
scenario_execution_data_df = self.filter_valid_data(execution_data_df)
152153
for var_name, var in self.scenario.variables.items():
153154
if issubclass(var.datatype, Enum):

causal_testing/generation/abstract_causal_test_case.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def _generate_concrete_tests(
131131
)
132132

133133
for v in self.scenario.inputs():
134-
if row[v.name] != v.cast(model[v.z3]):
134+
if v.name in row and row[v.name] != v.cast(model[v.z3]):
135135
constraints = "\n ".join([str(c) for c in self.scenario.constraints if v.name in str(c)])
136136
logger.warning(
137137
f"Unable to set variable {v.name} to {row[v.name]} because of constraints\n"

causal_testing/generation/enum_gen.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""This module contains the class EnumGen, which allows us to easily create
2+
generating uniform distributions from enums."""
3+
4+
from enum import Enum
5+
from scipy.stats import rv_discrete
6+
import numpy as np
7+
8+
9+
class EnumGen(rv_discrete):
    """Uniform discrete distribution over the members of an enum.

    Members are numbered 1..n in definition order and each one carries the same
    probability 1/n. This is helpful for generating concrete test inputs from
    abstract test cases.
    """

    def __init__(self, datatype: type[Enum]):
        """Index the members of the given enum class.

        :param datatype: The enum class whose members form the support.
        """
        super().__init__()
        # 1-based position -> member, plus the reverse lookup member -> position.
        self.datatype = dict(enumerate(datatype, 1))
        self.inverse_dt = {member: position for position, member in self.datatype.items()}

    def ppf(self, q):
        """Percent point function (inverse of `cdf`) at q of the given RV.

        Parameters
        ----------
        q : array_like
            Lower tail probability.

        Returns
        -------
        k : array_like
            Enum member(s) corresponding to the lower tail probability, q.
        """
        size = len(self.datatype)
        # ceil(size * q) is 0 when q == 0, which is not a valid 1-based position
        # and previously produced None. Clamp so the lowest quantile maps to the
        # first member (and nothing can exceed the last one).
        positions = np.clip(np.ceil(size * np.asarray(q)), 1, size)
        return np.vectorize(self.datatype.get)(positions)

    def cdf(self, k):
        """Cumulative distribution function of the given RV.

        Parameters
        ----------
        k : array_like
            Enum member(s) at which to evaluate the CDF.

        Returns
        -------
        cdf : ndarray
            Cumulative distribution function evaluated at `k`, i.e. the 1-based
            position of each member divided by the number of members.
        """
        return np.vectorize(self.inverse_dt.get)(k) / len(self.datatype)

causal_testing/json_front/json_class.py

Lines changed: 95 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,11 @@
2020
from causal_testing.specification.causal_specification import CausalSpecification
2121
from causal_testing.specification.scenario import Scenario
2222
from causal_testing.specification.variable import Input, Meta, Output
23-
from causal_testing.testing.base_test_case import BaseTestCase
2423
from causal_testing.testing.causal_test_case import CausalTestCase
24+
from causal_testing.testing.causal_test_result import CausalTestResult
2525
from causal_testing.testing.causal_test_engine import CausalTestEngine
2626
from causal_testing.testing.estimators import Estimator
27+
from causal_testing.testing.base_test_case import BaseTestCase
2728

2829
logger = logging.getLogger(__name__)
2930

@@ -41,7 +42,7 @@ class JsonUtility:
4142
:attr {Meta} metas: Causal variables representing metavariables.
4243
:attr {pd.DataFrame}: Pandas DataFrame containing runtime data.
4344
:attr {dict} test_plan: Dictionary containing the key value pairs from the loaded json test plan.
44-
:attr {Scenario} modelling_scenario:
45+
:attr {Scenario} scenario:
4546
:attr {CausalSpecification} causal_specification:
4647
"""
4748

@@ -75,6 +76,33 @@ def setup(self, scenario: Scenario):
7576
self._json_parse()
7677
self._populate_metas()
7778

79+
def _create_abstract_test_case(self, test, mutates, effects):
    """Build the abstract causal test case described by one JSON test entry.

    The entry must list exactly one mutation. If the mutated variable has no
    distribution, one is fitted to its column of ``self.data``, stored on the
    variable and written to the log file.

    :param test: Single JSON test definition stored in a mapping (dict)
    :param mutates: Mapping from mutation name to a callable that takes the
                    variable name and returns an intervention constraint
    :param effects: Mapping from effect name to the expected causal effect
    :returns: The abstract test case for this entry
    """
    mutations = test["mutations"]
    assert len(mutations) == 1
    treatment_var = next(self.scenario.variables[name] for name in mutations)

    if not treatment_var.distribution:
        # No distribution supplied: fit the common candidates to the observed
        # column and keep the one with the lowest sum-of-squares error.
        fitter = Fitter(self.data[treatment_var.name], distributions=get_common_distributions())
        fitter.fit()
        best_fit = fitter.get_best(method="sumsquare_error")
        dist, params = list(best_fit.items())[0]
        treatment_var.distribution = getattr(scipy.stats, dist)(**params)
        self._append_to_file(treatment_var.name + f" {dist}({params})", logging.INFO)

    intervention_constraints = [mutates[mutation](name) for name, mutation in mutations.items()]
    expected_causal_effect = {
        self.scenario.variables[name]: effects[effect] for name, effect in test["expected_effect"].items()
    }
    if "effect_modifiers" in test:
        effect_modifiers = {self.scenario.variables[name] for name in test["effect_modifiers"]}
    else:
        effect_modifiers = {}

    return AbstractCausalTestCase(
        scenario=self.scenario,
        intervention_constraints=intervention_constraints,
        treatment_variable=treatment_var,
        expected_causal_effect=expected_causal_effect,
        effect_modifiers=effect_modifiers,
        estimate_type=test["estimate_type"],
        effect=test.get("effect", "total"),
    )
105+
78106
def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False, mutates: dict = None):
79107
"""Runs and evaluates each test case specified in the JSON input
80108
@@ -84,23 +112,52 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
84112
:param f_flag: Failure flag that if True the script will stop executing when a test fails.
85113
"""
86114
failures = 0
115+
msg = ""
87116
for test in self.test_plan["tests"]:
88117
if "skip" in test and test["skip"]:
89118
continue
90119
test["estimator"] = estimators[test["estimator"]]
91120
if "mutations" in test:
92-
abstract_test = self._create_abstract_test_case(test, mutates, effects)
93-
94-
concrete_tests, dummy = abstract_test.generate_concrete_tests(5, 0.05)
95-
failures = self._execute_tests(concrete_tests, test, f_flag)
96-
msg = (
97-
f"Executing test: {test['name']}\n"
98-
+ "abstract_test\n"
99-
+ f"{abstract_test}\n"
100-
+ f"{abstract_test.treatment_variable.name},{abstract_test.treatment_variable.distribution}\n"
101-
+ f"Number of concrete tests for test case: {str(len(concrete_tests))}\n"
102-
+ f"{failures}/{len(concrete_tests)} failed for {test['name']}"
103-
)
121+
if test["estimate_type"] == "coefficient":
122+
base_test_case = BaseTestCase(
123+
treatment_variable=next(self.scenario.variables[v] for v in test["mutations"]),
124+
outcome_variable=next(self.scenario.variables[v] for v in test["expected_effect"]),
125+
effect=test.get("effect", "direct"),
126+
)
127+
assert len(test["expected_effect"]) == 1, "Can only have one expected effect."
128+
causal_test_case = CausalTestCase(
129+
base_test_case=base_test_case,
130+
expected_causal_effect=next(
131+
effects[effect] for variable, effect in test["expected_effect"].items()
132+
),
133+
estimate_type="coefficient",
134+
effect_modifier_configuration={
135+
self.scenario.variables[v] for v in test.get("effect_modifiers", [])
136+
},
137+
)
138+
result = self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag)
139+
msg = (
140+
f"Executing test: {test['name']} \n"
141+
+ f" {causal_test_case} \n"
142+
+ " "
143+
+ ("\n ").join(str(result[1]).split("\n"))
144+
+ "==============\n"
145+
+ f" Result: {'FAILED' if result[0] else 'Passed'}"
146+
)
147+
else:
148+
abstract_test = self._create_abstract_test_case(test, mutates, effects)
149+
concrete_tests, _ = abstract_test.generate_concrete_tests(5, 0.05)
150+
failures, _ = self._execute_tests(concrete_tests, test, f_flag)
151+
152+
msg = (
153+
f"Executing test: {test['name']} \n"
154+
+ " abstract_test \n"
155+
+ f" {abstract_test} \n"
156+
+ f" {abstract_test.treatment_variable.name},"
157+
+ f" {abstract_test.treatment_variable.distribution} \n"
158+
+ f" Number of concrete tests for test case: {str(len(concrete_tests))} \n"
159+
+ f" {failures}/{len(concrete_tests)} failed for {test['name']}"
160+
)
104161
self._append_to_file(msg, logging.INFO)
105162
else:
106163
outcome_variable = next(
@@ -118,47 +175,28 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
118175
treatment_value=test["treatment_value"],
119176
estimate_type=test["estimate_type"],
120177
)
121-
if self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag):
122-
result = "failed"
123-
else:
124-
result = "passed"
178+
failed, _ = self._execute_test_case(causal_test_case=causal_test_case, test=test, f_flag=f_flag)
125179

126180
msg = (
127181
f"Executing concrete test: {test['name']} \n"
128182
+ f"treatment variable: {test['treatment_variable']} \n"
129183
+ f"outcome_variable = {outcome_variable} \n"
130184
+ f"control value = {test['control_value']}, treatment value = {test['treatment_value']} \n"
131-
+ f"result - {result}"
185+
+ f"Result: {'FAILED' if failed else 'Passed'}"
132186
)
133187
self._append_to_file(msg, logging.INFO)
134188

135-
def _create_abstract_test_case(self, test, mutates, effects):
136-
assert len(test["mutations"]) == 1
137-
abstract_test = AbstractCausalTestCase(
138-
scenario=self.scenario,
139-
intervention_constraints=[mutates[v](k) for k, v in test["mutations"].items()],
140-
treatment_variable=next(self.scenario.variables[v] for v in test["mutations"]),
141-
expected_causal_effect={
142-
self.scenario.variables[variable]: effects[effect]
143-
for variable, effect in test["expected_effect"].items()
144-
},
145-
effect_modifiers={self.scenario.variables[v] for v in test["effect_modifiers"]}
146-
if "effect_modifiers" in test
147-
else {},
148-
estimate_type=test["estimate_type"],
149-
effect=test.get("effect", "total"),
150-
)
151-
return abstract_test
152-
153189
def _execute_tests(self, concrete_tests, test, f_flag):
154190
failures = 0
191+
details = []
155192
if "formula" in test:
156193
self._append_to_file(f"Estimator formula used for test: {test['formula']}")
157194
for concrete_test in concrete_tests:
158-
failed = self._execute_test_case(concrete_test, test, f_flag)
195+
failed, result = self._execute_test_case(concrete_test, test, f_flag)
196+
details.append(result)
159197
if failed:
160198
failures += 1
161-
return failures
199+
return failures, details
162200

163201
def _json_parse(self):
164202
"""Parse a JSON input file into inputs, outputs, metas and a test plan"""
@@ -175,15 +213,10 @@ def _populate_metas(self):
175213
"""
176214
for meta in self.scenario.variables_of_type(Meta):
177215
meta.populate(self.data)
178-
for var in self.scenario.variables_of_type(Meta).union(self.scenario.variables_of_type(Output)):
179-
if not var.distribution:
180-
fitter = Fitter(self.data[var.name], distributions=get_common_distributions())
181-
fitter.fit()
182-
(dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
183-
var.distribution = getattr(scipy.stats, dist)(**params)
184-
self._append_to_file(var.name + f" {dist}({params})", logging.INFO)
185-
186-
def _execute_test_case(self, causal_test_case: CausalTestCase, test: Iterable[Mapping], f_flag: bool) -> bool:
216+
217+
def _execute_test_case(
218+
self, causal_test_case: CausalTestCase, test: Iterable[Mapping], f_flag: bool
219+
) -> (bool, CausalTestResult):
187220
"""Executes a singular test case, prints the results and returns the test case result
188221
:param causal_test_case: The concrete test case to be executed
189222
:param test: Single JSON test definition stored in a mapping (dict)
@@ -193,7 +226,10 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, test: Iterable[Ma
193226
:rtype: bool
194227
"""
195228
failed = False
196-
causal_test_engine, estimation_model = self._setup_test(causal_test_case, test)
229+
230+
causal_test_engine, estimation_model = self._setup_test(
231+
causal_test_case, test, test["conditions"] if "conditions" in test else None
232+
)
197233
causal_test_result = causal_test_engine.execute_test(
198234
estimation_model, causal_test_case, estimate_type=causal_test_case.estimate_type
199235
)
@@ -216,18 +252,25 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, test: Iterable[Ma
216252
)
217253
failed = True
218254
logger.warning(" FAILED- expected %s, got %s", causal_test_case.expected_causal_effect, result_string)
219-
return failed
255+
return failed, causal_test_result
220256

221-
def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> tuple[CausalTestEngine, Estimator]:
257+
def _setup_test(
258+
self, causal_test_case: CausalTestCase, test: Mapping, conditions: list[str] = None
259+
) -> tuple[CausalTestEngine, Estimator]:
222260
"""Create the necessary inputs for a single test case
223261
:param causal_test_case: The concrete test case to be executed
224262
:param test: Single JSON test definition stored in a mapping (dict)
263+
:param conditions: A list of conditions which should be applied to the
264+
data. Conditions should be in the query format detailed at
265+
https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.query.html
225266
:returns:
226267
- causal_test_engine - Test Engine instance for the test being run
227268
- estimation_model - Estimator instance for the test being run
228269
"""
229270

230-
data_collector = ObservationalDataCollector(self.scenario, self.data)
271+
data_collector = ObservationalDataCollector(
272+
self.scenario, self.data.query(" & ".join(conditions)) if conditions else self.data
273+
)
231274
causal_test_engine = CausalTestEngine(self.causal_specification, data_collector, index_col=0)
232275

233276
minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)

causal_testing/specification/causal_dag.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,5 +521,14 @@ def identification(self, base_test_case: BaseTestCase):
521521
minimal_adjustment_set = min(minimal_adjustment_sets, key=len)
522522
return minimal_adjustment_set
523523

524+
def to_dot_string(self) -> str:
    """Render the causal DAG in DOT format.

    Every edge of ``self.graph`` becomes one ``source -> target;`` line inside
    a ``digraph G`` block.

    :return: DOT string of the DAG.
    """
    edge_lines = (f"{source} -> {target};\n" for source, target in self.graph.edges)
    return "digraph G {\n" + "".join(edge_lines) + "}"
532+
524533
def __str__(self):
525534
return f"Nodes: {self.graph.nodes}\nEdges: {self.graph.edges}"

causal_testing/specification/metamorphic_relation.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,10 @@ def execute_tests(self, data_collector: ExperimentalDataCollector):
102102
def assertion(self, source_output, follow_up_output):
103103
"""An assertion that should be applied to an individual metamorphic test run."""
104104

105+
@abstractmethod
106+
def to_json_stub(self, skip=True) -> dict:
107+
"""Convert to a JSON frontend stub, returned as a dict, for user customisation."""
108+
105109
@abstractmethod
106110
def test_oracle(self, test_results):
107111
"""A test oracle that assert whether the MR holds or not based on ALL test results.
@@ -129,6 +133,18 @@ def test_oracle(self, test_results):
129133
self.tests
130134
), f"{str(self)}: {len(test_results['fail'])}/{len(self.tests)} tests failed."
131135

136+
def to_json_stub(self, skip=True) -> dict:
    """Build a JSON frontend stub (as a dict) for the user to customise.

    :param skip: Value stored under the "skip" key of the stub.
    :return: Stub naming this relation, its treatment and output variables,
             and a placeholder expected effect.
    """
    stub = dict(
        name=str(self),
        estimator="LinearRegressionEstimator",
        estimate_type="coefficient",
        effect="direct",
        mutations=[self.treatment_var],
        expected_effect={self.output_var: "SomeEffect"},
        skip=skip,
    )
    return stub
147+
132148
def __str__(self):
133149
formatted_str = f"{self.treatment_var} --> {self.output_var}"
134150
if self.adjustment_vars:
@@ -149,6 +165,18 @@ def test_oracle(self, test_results):
149165
len(test_results["fail"]) == 0
150166
), f"{str(self)}: {len(test_results['fail'])}/{len(self.tests)} tests failed."
151167

168+
def to_json_stub(self, skip=True) -> dict:
    """Build a JSON frontend stub (as a dict) for the user to customise.

    :param skip: Value stored under the "skip" key of the stub.
    :return: Stub naming this relation, its treatment and output variables,
             with the expected effect set to "NoEffect".
    """
    stub = dict(
        name=str(self),
        estimator="LinearRegressionEstimator",
        estimate_type="coefficient",
        effect="direct",
        mutations=[self.treatment_var],
        expected_effect={self.output_var: "NoEffect"},
        skip=skip,
    )
    return stub
179+
152180
def __str__(self):
153181
formatted_str = f"{self.treatment_var} _||_ {self.output_var}"
154182
if self.adjustment_vars:

causal_testing/testing/causal_test_case.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def __init__(
2727
self,
2828
base_test_case: BaseTestCase,
2929
expected_causal_effect: CausalTestOutcome,
30-
control_value: Any,
30+
control_value: Any = None,
3131
treatment_value: Any = None,
3232
estimate_type: str = "ate",
3333
effect_modifier_configuration: dict[Variable:Any] = None,

causal_testing/testing/causal_test_engine.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,15 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
174174
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
175175
confidence_intervals=confidence_intervals,
176176
)
177+
elif estimate_type == "coefficient":
178+
logger.debug("calculating coefficient")
179+
coefficient, confidence_intervals = estimator.estimate_unit_ate()
180+
causal_test_result = CausalTestResult(
181+
estimator=estimator,
182+
test_value=TestValue("coefficient", coefficient),
183+
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
184+
confidence_intervals=confidence_intervals,
185+
)
177186
elif estimate_type == "ate":
178187
logger.debug("calculating ate")
179188
ate, confidence_intervals = estimator.estimate_ate()

0 commit comments

Comments
 (0)