Skip to content

Commit 7c185b4

Browse files
Merge branch 'main' into json_concrete_param
# Conflicts:
#	causal_testing/json_front/json_class.py
2 parents a0defb2 + d673d0a commit 7c185b4

File tree

25 files changed

+382
-156
lines changed

25 files changed

+382
-156
lines changed

.github/workflows/ci-tests.yaml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@ jobs:
2424
run: |
2525
conda install -c conda-forge pygraphviz
2626
python --version
27-
pip install -e .
27+
pip install -e . --no-cache-dir
2828
pip install -e .[test]
2929
pip install pytest pytest-cov
3030
shell: bash -l {0}

causal_testing/json_front/json_class.py

Lines changed: 17 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -56,13 +56,15 @@ def __init__(self, output_path: str, output_overwrite: bool = False):
5656
self.output_path = Path(output_path)
5757
self.check_file_exists(self.output_path, output_overwrite)
5858

59-
def set_paths(self, json_path: str, dag_path: str, data_paths: str):
59+
def set_paths(self, json_path: str, dag_path: str, data_paths: list[str] = None):
6060
"""
6161
Takes the paths to the scenario-specific files and stores them together as a JsonClassPaths object
6262
:param json_path: string path representation to .json file containing test specifications
6363
:param dag_path: string path representation to the .dot file containing the Causal DAG
6464
:param data_paths: list of string path representations to the data files (defaults to an empty list if omitted)
6565
"""
66+
if data_paths is None:
67+
data_paths = []
6668
self.input_paths = JsonClassPaths(json_path=json_path, dag_path=dag_path, data_paths=data_paths)
6769

6870
def setup(self, scenario: Scenario):
@@ -73,7 +75,17 @@ def setup(self, scenario: Scenario):
7375
self.causal_specification = CausalSpecification(
7476
scenario=self.scenario, causal_dag=CausalDAG(self.input_paths.dag_path)
7577
)
76-
self._json_parse()
78+
# Parse the JSON test plan
79+
with open(self.input_paths.json_path, encoding="utf-8") as f:
80+
self.test_plan = json.load(f)
81+
# Populate the data
82+
if self.input_paths.data_paths:
83+
self.data = pd.concat([pd.read_csv(data_file, header=0) for data_file in self.input_paths.data_paths])
84+
if len(self.data) == 0:
85+
raise ValueError(
86+
"No data found, either provide a path to a file containing data or manually populate the .data "
87+
"attribute with a dataframe before calling .setup()"
88+
)
7789
self._populate_metas()
7890

7991
def _create_abstract_test_case(self, test, mutates, effects):
@@ -146,6 +158,7 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
146158
+ f"control value = {test['control_value']}, treatment value = {test['treatment_value']} \n"
147159
+ f"Result: {'FAILED' if failed else 'Passed'}"
148160
)
161+
print(msg)
149162
self._append_to_file(msg, logging.INFO)
150163

151164
def _run_coefficient_test(self, test: dict, f_flag: bool, effects: dict):
@@ -225,15 +238,6 @@ def _execute_tests(self, concrete_tests, test, f_flag):
225238
failures += 1
226239
return failures, details
227240

228-
def _json_parse(self):
229-
"""Parse a JSON input file into inputs, outputs, metas and a test plan"""
230-
with open(self.input_paths.json_path, encoding="utf-8") as f:
231-
self.test_plan = json.load(f)
232-
for data_file in self.input_paths.data_paths:
233-
df = pd.read_csv(data_file, header=0)
234-
self.data.append(df)
235-
self.data = pd.concat(self.data)
236-
237241
def _populate_metas(self):
238242
"""
239243
Populate data with meta-variable values and add distributions to Causal Testing Framework Variables
@@ -257,13 +261,11 @@ def _execute_test_case(
257261
causal_test_engine, estimation_model = self._setup_test(
258262
causal_test_case, test, test["conditions"] if "conditions" in test else None
259263
)
260-
causal_test_result = causal_test_engine.execute_test(
261-
estimation_model, causal_test_case, estimate_type=causal_test_case.estimate_type
262-
)
264+
causal_test_result = causal_test_engine.execute_test(estimation_model, causal_test_case)
263265

264266
test_passes = causal_test_case.expected_causal_effect.apply(causal_test_result)
265267

266-
if causal_test_result.ci_low() and causal_test_result.ci_high():
268+
if causal_test_result.ci_low() is not None and causal_test_result.ci_high() is not None:
267269
result_string = (
268270
f"{causal_test_result.ci_low()} < {causal_test_result.test_value.value} < "
269271
f"{causal_test_result.ci_high()}"
@@ -378,7 +380,6 @@ def get_args(test_args=None) -> argparse.Namespace:
378380
parser.add_argument(
379381
"--data_path",
380382
help="Specify path to file containing runtime data",
381-
required=True,
382383
nargs="+",
383384
)
384385
parser.add_argument(

causal_testing/specification/metamorphic_relation.py

Lines changed: 41 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7,13 +7,18 @@
77
from abc import abstractmethod
88
from typing import Iterable
99
from itertools import combinations
10-
import numpy as np
11-
import pandas as pd
10+
import argparse
11+
import logging
12+
import json
1213
import networkx as nx
14+
import pandas as pd
15+
import numpy as np
1316

1417
from causal_testing.specification.causal_specification import CausalDAG, Node
1518
from causal_testing.data_collection.data_collector import ExperimentalDataCollector
1619

20+
logger = logging.getLogger(__name__)
21+
1722

1823
@dataclass(order=True)
1924
class MetamorphicRelation:
@@ -142,6 +147,7 @@ def to_json_stub(self, skip=True) -> dict:
142147
"effect": "direct",
143148
"mutations": [self.treatment_var],
144149
"expected_effect": {self.output_var: "SomeEffect"},
150+
"formula": f"{self.output_var} ~ {' + '.join([self.treatment_var] + self.adjustment_vars)}",
145151
"skip": skip,
146152
}
147153

@@ -174,6 +180,7 @@ def to_json_stub(self, skip=True) -> dict:
174180
"effect": "direct",
175181
"mutations": [self.treatment_var],
176182
"expected_effect": {self.output_var: "NoEffect"},
183+
"formula": f"{self.output_var} ~ {' + '.join([self.treatment_var] + self.adjustment_vars)}",
177184
"skip": skip,
178185
}
179186

@@ -244,3 +251,35 @@ def generate_metamorphic_relations(dag: CausalDAG) -> list[MetamorphicRelation]:
244251
metamorphic_relations.append(ShouldCause(v, u, adj_set, dag))
245252

246253
return metamorphic_relations
254+
255+
256+
if __name__ == "__main__": # pragma: no cover
257+
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.INFO)
258+
parser = argparse.ArgumentParser(
259+
description="A script for generating metamorphic relations to test the causal relationships in a given DAG."
260+
)
261+
parser.add_argument(
262+
"--dag_path",
263+
"-d",
264+
help="Specify path to file containing the DAG, normally a .dot file.",
265+
required=True,
266+
)
267+
parser.add_argument(
268+
"--output_path",
269+
"-o",
270+
help="Specify path where tests should be saved, normally a .json file.",
271+
required=True,
272+
)
273+
args = parser.parse_args()
274+
275+
causal_dag = CausalDAG(args.dag_path)
276+
relations = generate_metamorphic_relations(causal_dag)
277+
tests = [
278+
relation.to_json_stub(skip=False)
279+
for relation in relations
280+
if len(list(causal_dag.graph.predecessors(relation.output_var))) > 0
281+
]
282+
283+
logger.info(f"Generated {len(tests)} tests. Saving to {args.output_path}.")
284+
with open(args.output_path, "w", encoding="utf-8") as f:
285+
json.dump({"tests": tests}, f, indent=2)

causal_testing/testing/causal_test_engine.py

Lines changed: 12 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -81,7 +81,6 @@ def execute_test_suite(self, test_suite: CausalTestSuite) -> list[CausalTestResu
8181

8282
estimators = test_suite[edge]["estimators"]
8383
tests = test_suite[edge]["tests"]
84-
estimate_type = test_suite[edge]["estimate_type"]
8584
results = {}
8685
for estimator_class in estimators:
8786
causal_test_results = []
@@ -96,16 +95,14 @@ def execute_test_suite(self, test_suite: CausalTestSuite) -> list[CausalTestResu
9695
)
9796
if estimator.df is None:
9897
estimator.df = self.scenario_execution_data_df
99-
causal_test_result = self._return_causal_test_results(estimate_type, estimator, test)
98+
causal_test_result = self._return_causal_test_results(estimator, test)
10099
causal_test_results.append(causal_test_result)
101100

102101
results[estimator_class.__name__] = causal_test_results
103102
test_suite_results[edge] = results
104103
return test_suite_results
105104

106-
def execute_test(
107-
self, estimator: type(Estimator), causal_test_case: CausalTestCase, estimate_type: str = "ate"
108-
) -> CausalTestResult:
105+
def execute_test(self, estimator: type(Estimator), causal_test_case: CausalTestCase) -> CausalTestResult:
109106
"""Execute a causal test case and return the causal test result.
110107
111108
Test case execution proceeds with the following steps:
@@ -120,7 +117,6 @@ def execute_test(
120117
121118
:param estimator: A reference to an Estimator class.
122119
:param causal_test_case: The CausalTestCase object to be tested
123-
:param estimate_type: A string which denotes the type of estimate to return, ATE or CATE.
124120
:return causal_test_result: A CausalTestResult for the executed causal test case.
125121
"""
126122
if self.scenario_execution_data_df.empty:
@@ -142,18 +138,17 @@ def execute_test(
142138
if self._check_positivity_violation(variables_for_positivity):
143139
raise ValueError("POSITIVITY VIOLATION -- Cannot proceed.")
144140

145-
causal_test_result = self._return_causal_test_results(estimate_type, estimator, causal_test_case)
141+
causal_test_result = self._return_causal_test_results(estimator, causal_test_case)
146142
return causal_test_result
147143

148-
def _return_causal_test_results(self, estimate_type, estimator, causal_test_case):
144+
def _return_causal_test_results(self, estimator, causal_test_case):
149145
"""Depending on the estimator used, calculate the 95% confidence intervals and return in a causal_test_result
150146
151-
:param estimate_type: A string which denotes the type of estimate to return
152147
:param estimator: An Estimator class object
153148
:param causal_test_case: The concrete test case to be executed
154149
:return: a CausalTestResult object containing the confidence intervals
155150
"""
156-
if estimate_type == "cate":
151+
if causal_test_case.estimate_type == "cate":
157152
logger.debug("calculating cate")
158153
if not hasattr(estimator, "estimate_cates"):
159154
raise NotImplementedError(f"{estimator.__class__} has no CATE method.")
@@ -165,7 +160,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
165160
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
166161
confidence_intervals=confidence_intervals,
167162
)
168-
elif estimate_type == "risk_ratio":
163+
elif causal_test_case.estimate_type == "risk_ratio":
169164
logger.debug("calculating risk_ratio")
170165
risk_ratio, confidence_intervals = estimator.estimate_risk_ratio()
171166
causal_test_result = CausalTestResult(
@@ -174,7 +169,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
174169
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
175170
confidence_intervals=confidence_intervals,
176171
)
177-
elif estimate_type == "coefficient":
172+
elif causal_test_case.estimate_type == "coefficient":
178173
logger.debug("calculating coefficient")
179174
coefficient, confidence_intervals = estimator.estimate_unit_ate()
180175
causal_test_result = CausalTestResult(
@@ -183,7 +178,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
183178
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
184179
confidence_intervals=confidence_intervals,
185180
)
186-
elif estimate_type == "ate":
181+
elif causal_test_case.estimate_type == "ate":
187182
logger.debug("calculating ate")
188183
ate, confidence_intervals = estimator.estimate_ate()
189184
causal_test_result = CausalTestResult(
@@ -194,7 +189,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
194189
)
195190
# causal_test_result = CausalTestResult(minimal_adjustment_set, ate, confidence_intervals)
196191
# causal_test_result.apply_test_oracle_procedure(self.causal_test_case.expected_causal_effect)
197-
elif estimate_type == "ate_calculated":
192+
elif causal_test_case.estimate_type == "ate_calculated":
198193
logger.debug("calculating ate")
199194
ate, confidence_intervals = estimator.estimate_ate_calculated()
200195
causal_test_result = CausalTestResult(
@@ -206,7 +201,9 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
206201
# causal_test_result = CausalTestResult(minimal_adjustment_set, ate, confidence_intervals)
207202
# causal_test_result.apply_test_oracle_procedure(self.causal_test_case.expected_causal_effect)
208203
else:
209-
raise ValueError(f"Invalid estimate type {estimate_type}, expected 'ate', 'cate', or 'risk_ratio'")
204+
raise ValueError(
205+
f"Invalid estimate type {causal_test_case.estimate_type}, expected 'ate', 'cate', or 'risk_ratio'"
206+
)
210207
return causal_test_result
211208

212209
def _check_positivity_violation(self, variables_list):

causal_testing/testing/causal_test_outcome.py

Lines changed: 28 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
ExactValue, Positive, Negative, SomeEffect, NoEffect"""
44

55
from abc import ABC, abstractmethod
6+
from collections.abc import Iterable
67
import numpy as np
78

89
from causal_testing.testing.causal_test_result import CausalTestResult
@@ -26,8 +27,12 @@ class SomeEffect(CausalTestOutcome):
2627
"""An extension of TestOutcome representing that the expected causal effect should not be zero."""
2728

2829
def apply(self, res: CausalTestResult) -> bool:
29-
if res.test_value.type in {"ate", "coefficient"}:
30+
if res.test_value.type == "ate":
3031
return (0 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 0)
32+
if res.test_value.type == "coefficient":
33+
ci_low = res.ci_low() if isinstance(res.ci_low(), Iterable) else [res.ci_low()]
34+
ci_high = res.ci_high() if isinstance(res.ci_high(), Iterable) else [res.ci_high()]
35+
return any(0 < ci_low < ci_high or ci_low < ci_high < 0 for ci_low, ci_high in zip(ci_low, ci_high))
3136
if res.test_value.type == "risk_ratio":
3237
return (1 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 1)
3338
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
@@ -36,32 +41,41 @@ def apply(self, res: CausalTestResult) -> bool:
3641
class NoEffect(CausalTestOutcome):
3742
"""An extension of TestOutcome representing that the expected causal effect should be zero."""
3843

44+
def __init__(self, atol: float = 1e-10):
45+
self.atol = atol
46+
3947
def apply(self, res: CausalTestResult) -> bool:
40-
print("RESULT", res)
41-
if res.test_value.type in {"ate", "coefficient"}:
42-
return (res.ci_low() < 0 < res.ci_high()) or (abs(res.test_value.value) < 1e-10)
48+
if res.test_value.type == "ate":
49+
return (res.ci_low() < 0 < res.ci_high()) or (abs(res.test_value.value) < self.atol)
50+
if res.test_value.type == "coefficient":
51+
ci_low = res.ci_low() if isinstance(res.ci_low(), Iterable) else [res.ci_low()]
52+
ci_high = res.ci_high() if isinstance(res.ci_high(), Iterable) else [res.ci_high()]
53+
value = res.test_value.value if isinstance(res.ci_high(), Iterable) else [res.test_value.value]
54+
return all(ci_low < 0 < ci_high for ci_low, ci_high in zip(ci_low, ci_high)) or all(
55+
abs(v) < self.atol for v in value
56+
)
4357
if res.test_value.type == "risk_ratio":
44-
return (res.ci_low() < 1 < res.ci_high()) or np.isclose(res.test_value.value, 1.0, atol=1e-10)
58+
return (res.ci_low() < 1 < res.ci_high()) or np.isclose(res.test_value.value, 1.0, atol=self.atol)
4559
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
4660

4761

4862
class ExactValue(SomeEffect):
4963
"""An extension of TestOutcome representing that the expected causal effect should be a specific value."""
5064

51-
def __init__(self, value: float, tolerance: float = None):
65+
def __init__(self, value: float, atol: float = None):
5266
self.value = value
53-
if tolerance is None:
54-
self.tolerance = value * 0.05
67+
if atol is None:
68+
self.atol = value * 0.05
5569
else:
56-
self.tolerance = tolerance
70+
self.atol = atol
5771

5872
def apply(self, res: CausalTestResult) -> bool:
5973
if res.ci_valid():
60-
return super().apply(res) and np.isclose(res.test_value.value, self.value, atol=self.tolerance)
61-
return np.isclose(res.test_value.value, self.value, atol=self.tolerance)
74+
return super().apply(res) and np.isclose(res.test_value.value, self.value, atol=self.atol)
75+
return np.isclose(res.test_value.value, self.value, atol=self.atol)
6276

6377
def __str__(self):
64-
return f"ExactValue: {self.value}±{self.tolerance}"
78+
return f"ExactValue: {self.value}±{self.atol}"
6579

6680

6781
class Positive(SomeEffect):
@@ -74,6 +88,7 @@ def apply(self, res: CausalTestResult) -> bool:
7488
return res.test_value.value > 0
7589
if res.test_value.type == "risk_ratio":
7690
return res.test_value.value > 1
91+
# Dead code but necessary for pylint
7792
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
7893

7994

@@ -87,4 +102,5 @@ def apply(self, res: CausalTestResult) -> bool:
87102
return res.test_value.value < 0
88103
if res.test_value.type == "risk_ratio":
89104
return res.test_value.value < 1
105+
# Dead code but necessary for pylint
90106
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")

0 commit comments

Comments
 (0)