Skip to content

Commit 220b225

Browse files
Merge branch 'main' into dev_docs
2 parents c33f7cc + cabc194 commit 220b225

File tree

13 files changed

+194
-102
lines changed

13 files changed

+194
-102
lines changed

causal_testing/json_front/json_class.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,15 @@ def __init__(self, output_path: str, output_overwrite: bool = False):
5656
self.output_path = Path(output_path)
5757
self.check_file_exists(self.output_path, output_overwrite)
5858

59-
def set_paths(self, json_path: str, dag_path: str, data_paths: list[str] = None):
    """
    Record the locations of the scenario-specific input files on ``self.input_paths``.

    :param json_path: string path representation to .json file containing test specifications
    :param dag_path: string path representation to the .dot file containing the Causal DAG
    :param data_paths: optional list of string paths to the data files. ``None`` (the default) is stored as an
        empty list, and a single string is treated as a one-element list.
    """
    if data_paths is None:
        data_paths = []
    elif isinstance(data_paths, str):
        # This parameter used to be annotated as a plain string; wrap a lone path so that code iterating
        # over the stored paths sees one path rather than the characters of the string.
        # NOTE(review): assumes JsonClassPaths expects an iterable of paths - confirm against its definition.
        data_paths = [data_paths]
    self.input_paths = JsonClassPaths(json_path=json_path, dag_path=dag_path, data_paths=data_paths)
6769

6870
def setup(self, scenario: Scenario):
@@ -73,7 +75,12 @@ def setup(self, scenario: Scenario):
7375
self.causal_specification = CausalSpecification(
7476
scenario=self.scenario, causal_dag=CausalDAG(self.input_paths.dag_path)
7577
)
76-
self._json_parse()
78+
# Parse the JSON test plan
79+
with open(self.input_paths.json_path, encoding="utf-8") as f:
80+
self.test_plan = json.load(f)
81+
# Populate the data
82+
if self.input_paths.data_paths:
83+
self.data = pd.concat([pd.read_csv(data_file, header=0) for data_file in self.input_paths.data_paths])
7784
self._populate_metas()
7885

7986
def _create_abstract_test_case(self, test, mutates, effects):
@@ -144,6 +151,7 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
144151
+ "==============\n"
145152
+ f" Result: {'FAILED' if result[0] else 'Passed'}"
146153
)
154+
print(msg)
147155
else:
148156
abstract_test = self._create_abstract_test_case(test, mutates, effects)
149157
concrete_tests, _ = abstract_test.generate_concrete_tests(5, 0.05)
@@ -198,15 +206,6 @@ def _execute_tests(self, concrete_tests, test, f_flag):
198206
failures += 1
199207
return failures, details
200208

201-
def _json_parse(self):
202-
"""Parse a JSON input file into inputs, outputs, metas and a test plan"""
203-
with open(self.input_paths.json_path, encoding="utf-8") as f:
204-
self.test_plan = json.load(f)
205-
for data_file in self.input_paths.data_paths:
206-
df = pd.read_csv(data_file, header=0)
207-
self.data.append(df)
208-
self.data = pd.concat(self.data)
209-
210209
def _populate_metas(self):
211210
"""
212211
Populate data with meta-variable values and add distributions to Causal Testing Framework Variables
@@ -236,7 +235,7 @@ def _execute_test_case(
236235

237236
test_passes = causal_test_case.expected_causal_effect.apply(causal_test_result)
238237

239-
if causal_test_result.ci_low() and causal_test_result.ci_high():
238+
if causal_test_result.ci_low() is not None and causal_test_result.ci_high() is not None:
240239
result_string = (
241240
f"{causal_test_result.ci_low()} < {causal_test_result.test_value.value} < "
242241
f"{causal_test_result.ci_high()}"
@@ -351,7 +350,6 @@ def get_args(test_args=None) -> argparse.Namespace:
351350
parser.add_argument(
352351
"--data_path",
353352
help="Specify path to file containing runtime data",
354-
required=True,
355353
nargs="+",
356354
)
357355
parser.add_argument(

causal_testing/specification/metamorphic_relation.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,18 @@
77
from abc import abstractmethod
88
from typing import Iterable
99
from itertools import combinations
10-
import numpy as np
11-
import pandas as pd
10+
import argparse
11+
import logging
12+
import json
1213
import networkx as nx
14+
import pandas as pd
15+
import numpy as np
1316

1417
from causal_testing.specification.causal_specification import CausalDAG, Node
1518
from causal_testing.data_collection.data_collector import ExperimentalDataCollector
1619

20+
logger = logging.getLogger(__name__)
21+
1722

1823
@dataclass(order=True)
1924
class MetamorphicRelation:
@@ -142,6 +147,7 @@ def to_json_stub(self, skip=True) -> dict:
142147
"effect": "direct",
143148
"mutations": [self.treatment_var],
144149
"expected_effect": {self.output_var: "SomeEffect"},
150+
"formula": f"{self.output_var} ~ {' + '.join([self.treatment_var] + self.adjustment_vars)}",
145151
"skip": skip,
146152
}
147153

@@ -174,6 +180,7 @@ def to_json_stub(self, skip=True) -> dict:
174180
"effect": "direct",
175181
"mutations": [self.treatment_var],
176182
"expected_effect": {self.output_var: "NoEffect"},
183+
"formula": f"{self.output_var} ~ {' + '.join([self.treatment_var] + self.adjustment_vars)}",
177184
"skip": skip,
178185
}
179186

@@ -244,3 +251,35 @@ def generate_metamorphic_relations(dag: CausalDAG) -> list[MetamorphicRelation]:
244251
metamorphic_relations.append(ShouldCause(v, u, adj_set, dag))
245252

246253
return metamorphic_relations
254+
255+
256+
if __name__ == "__main__":  # pragma: no cover
    # Command-line entry point: load a DAG, generate its metamorphic relations and write them out as JSON
    # test stubs under a top-level "tests" key.
    logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.INFO)
    parser = argparse.ArgumentParser(
        description="A script for generating metamorphic relations to test the causal relationships in a given DAG."
    )
    parser.add_argument(
        "--dag_path",
        "-d",
        help="Specify path to file containing the DAG, normally a .dot file.",
        required=True,
    )
    parser.add_argument(
        "--output_path",
        "-o",
        help="Specify path where tests should be saved, normally a .json file.",
        required=True,
    )
    args = parser.parse_args()

    causal_dag = CausalDAG(args.dag_path)
    relations = generate_metamorphic_relations(causal_dag)
    # Only keep relations whose output variable has at least one predecessor in the DAG.
    tests = [
        relation.to_json_stub(skip=False)
        for relation in relations
        if len(list(causal_dag.graph.predecessors(relation.output_var))) > 0
    ]

    # Pass the values as lazy %-style arguments instead of an f-string so the message is only formatted
    # when the record is actually emitted; the rendered text is unchanged.
    logger.info("Generated %d tests. Saving to %s.", len(tests), args.output_path)
    with open(args.output_path, "w", encoding="utf-8") as f:
        json.dump({"tests": tests}, f, indent=2)

causal_testing/testing/causal_test_outcome.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
ExactValue, Positive, Negative, SomeEffect, NoEffect"""
44

55
from abc import ABC, abstractmethod
6+
from collections.abc import Iterable
67
import numpy as np
78

89
from causal_testing.testing.causal_test_result import CausalTestResult
@@ -26,8 +27,12 @@ class SomeEffect(CausalTestOutcome):
2627
"""An extension of TestOutcome representing that the expected causal effect should not be zero."""
2728

2829
def apply(self, res: CausalTestResult) -> bool:
    """
    Check that the confidence interval of the result excludes the "no effect" value.

    :param res: the causal test result whose test value type and confidence interval bounds are inspected
    :return: for "ate", True if the interval lies strictly above or strictly below 0; for "coefficient", True
        if any (low, high) pair of bounds does; for "risk_ratio", the same check made against 1
    :raises ValueError: if the test value type is not "ate", "coefficient" or "risk_ratio"
    """
    effect_type = res.test_value.type
    if effect_type not in {"ate", "coefficient", "risk_ratio"}:
        raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")

    # Fetch each bound once instead of re-calling the accessors in every comparison.
    ci_low, ci_high = res.ci_low(), res.ci_high()
    if effect_type == "ate":
        return (0 < ci_low < ci_high) or (ci_low < ci_high < 0)
    if effect_type == "risk_ratio":
        return (1 < ci_low < ci_high) or (ci_low < ci_high < 1)

    # "coefficient": the bounds may be a single number or one bound per coefficient, so normalise both to
    # iterables. Distinct loop names avoid shadowing the collections being zipped.
    lows = ci_low if isinstance(ci_low, Iterable) else [ci_low]
    highs = ci_high if isinstance(ci_high, Iterable) else [ci_high]
    return any(0 < low < high or low < high < 0 for low, high in zip(lows, highs))
@@ -36,32 +41,41 @@ def apply(self, res: CausalTestResult) -> bool:
3641
class NoEffect(CausalTestOutcome):
    """An extension of TestOutcome representing that the expected causal effect should be zero."""

    def __init__(self, atol: float = 1e-10):
        """
        :param atol: absolute tolerance within which a test value is treated as "no effect"
        """
        self.atol = atol

    def apply(self, res: CausalTestResult) -> bool:
        """
        Check that the result is consistent with there being no causal effect.

        :param res: the causal test result to check
        :return: for "ate", True if the confidence interval strictly contains 0 or the test value is below
            ``atol`` in magnitude; for "coefficient", True if every interval strictly contains 0 or every value
            is below ``atol`` in magnitude; for "risk_ratio", True if the interval strictly contains 1 or the
            test value is within ``atol`` of 1.0 according to ``np.isclose``
        :raises ValueError: if the test value type is not "ate", "coefficient" or "risk_ratio"
        """
        if res.test_value.type == "ate":
            return (res.ci_low() < 0 < res.ci_high()) or (abs(res.test_value.value) < self.atol)
        if res.test_value.type == "coefficient":
            # Bounds and value may each be a single number or one entry per coefficient; normalise each
            # according to its OWN type so that a scalar value is never iterated directly.
            lows = res.ci_low() if isinstance(res.ci_low(), Iterable) else [res.ci_low()]
            highs = res.ci_high() if isinstance(res.ci_high(), Iterable) else [res.ci_high()]
            values = res.test_value.value if isinstance(res.test_value.value, Iterable) else [res.test_value.value]
            return all(low < 0 < high for low, high in zip(lows, highs)) or all(
                abs(v) < self.atol for v in values
            )
        if res.test_value.type == "risk_ratio":
            return (res.ci_low() < 1 < res.ci_high()) or np.isclose(res.test_value.value, 1.0, atol=self.atol)
        raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
4660

4761

4862
class ExactValue(SomeEffect):
    """An extension of TestOutcome representing that the expected causal effect should be a specific value."""

    def __init__(self, value: float, atol: float = None):
        """
        :param value: the expected value of the causal effect
        :param atol: absolute tolerance around ``value``; defaults to 5% of the magnitude of ``value``
        """
        self.value = value
        if atol is None:
            # Use the magnitude: a negative tolerance (from a negative expected value) would make the
            # np.isclose comparison below impossible to satisfy.
            self.atol = abs(value) * 0.05
        else:
            self.atol = atol

    def apply(self, res: CausalTestResult) -> bool:
        """
        Check that the test value is within ``atol`` of the expected value.

        :param res: the causal test result to check
        :return: the ``np.isclose`` comparison of the test value against ``value``; when ``res.ci_valid()`` is
            true the ``SomeEffect`` check must pass as well
        """
        if res.ci_valid():
            return super().apply(res) and np.isclose(res.test_value.value, self.value, atol=self.atol)
        return np.isclose(res.test_value.value, self.value, atol=self.atol)

    def __str__(self):
        return f"ExactValue: {self.value}±{self.atol}"
6579

6680

6781
class Positive(SomeEffect):
@@ -74,6 +88,7 @@ def apply(self, res: CausalTestResult) -> bool:
7488
return res.test_value.value > 0
7589
if res.test_value.type == "risk_ratio":
7690
return res.test_value.value > 1
91+
# Dead code but necessary for pylint
7792
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
7893

7994

@@ -87,4 +102,5 @@ def apply(self, res: CausalTestResult) -> bool:
87102
return res.test_value.value < 0
88103
if res.test_value.type == "risk_ratio":
89104
return res.test_value.value < 1
105+
# Dead code but necessary for pylint
90106
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")

causal_testing/testing/causal_test_result.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,27 @@ def __init__(
4343
self.effect_modifier_configuration = {}
4444

4545
def __str__(self):
46+
def push(s, inc=" "):
47+
return inc + str(s).replace("\n", "\n" + inc)
48+
49+
result_str = str(self.test_value.value)
50+
if "\n" in result_str:
51+
result_str = "\n" + push(self.test_value.value)
4652
base_str = (
4753
f"Causal Test Result\n==============\n"
4854
f"Treatment: {self.estimator.treatment}\n"
4955
f"Control value: {self.estimator.control_value}\n"
5056
f"Treatment value: {self.estimator.treatment_value}\n"
5157
f"Outcome: {self.estimator.outcome}\n"
5258
f"Adjustment set: {self.adjustment_set}\n"
53-
f"{self.test_value.type}: {self.test_value.value}\n"
59+
f"{self.test_value.type}: {result_str}\n"
5460
)
5561
confidence_str = ""
5662
if self.confidence_intervals:
57-
confidence_str += f"Confidence intervals: {self.confidence_intervals}\n"
63+
ci_str = " " + str(self.confidence_intervals)
64+
if "\n" in ci_str:
65+
ci_str = " " + push(pd.DataFrame(self.confidence_intervals).transpose().to_string(header=False))
66+
confidence_str += f"Confidence intervals:{ci_str}\n"
5867
return base_str + confidence_str
5968

6069
def to_dict(self):
@@ -76,14 +85,14 @@ def to_dict(self):
7685

7786
def ci_low(self):
7887
"""Return the lower bracket of the confidence intervals."""
79-
if self.confidence_intervals and all(self.confidence_intervals):
80-
return min(self.confidence_intervals)
88+
if self.confidence_intervals:
89+
return self.confidence_intervals[0]
8190
return None
8291

8392
def ci_high(self):
8493
"""Return the higher bracket of the confidence intervals."""
85-
if self.confidence_intervals and all(self.confidence_intervals):
86-
return max(self.confidence_intervals)
94+
if self.confidence_intervals:
95+
return self.confidence_intervals[1]
8796
return None
8897

8998
def ci_valid(self) -> bool:

causal_testing/testing/estimators.py

Lines changed: 19 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -335,10 +335,21 @@ def estimate_unit_ate(self) -> float:
335335
:return: The unit average treatment effect and the 95% Wald confidence intervals.
336336
"""
337337
model = self._run_linear_regression()
338-
assert self.treatment in model.params, f"{self.treatment} not in {model.params}"
339-
unit_effect = model.params[[self.treatment]].values[0] # Unit effect is the coefficient of the treatment
340-
[ci_low, ci_high] = self._get_confidence_intervals(model)
341-
338+
newline = "\n"
339+
print(model.conf_int())
340+
treatment = [self.treatment]
341+
if str(self.df.dtypes[self.treatment]) == "object":
342+
design_info = dmatrix(self.formula.split("~")[1], self.df).design_info
343+
treatment = design_info.column_names[design_info.term_name_slices[self.treatment]]
344+
assert set(treatment).issubset(
345+
model.params.index.tolist()
346+
), f"{treatment} not in\n{' '+str(model.params.index).replace(newline, newline+' ')}"
347+
unit_effect = model.params[treatment] # Unit effect is the coefficient of the treatment
348+
[ci_low, ci_high] = self._get_confidence_intervals(model, treatment)
349+
if str(self.df.dtypes[self.treatment]) != "object":
350+
unit_effect = unit_effect[0]
351+
ci_low = ci_low[0]
352+
ci_high = ci_high[0]
342353
return unit_effect, [ci_low, ci_high]
343354

344355
def estimate_ate(self) -> tuple[float, list[float, float], float]:
@@ -353,12 +364,6 @@ def estimate_ate(self) -> tuple[float, list[float, float], float]:
353364
# Create an empty individual for the control and treated
354365
individuals = pd.DataFrame(1, index=["control", "treated"], columns=model.params.index)
355366

356-
# This is a temporary hack
357-
# for t in self.square_terms:
358-
# individuals[t + "^2"] = individuals[t] ** 2
359-
# for a, b in self.product_terms:
360-
# individuals[f"{a}*{b}"] = individuals[a] * individuals[b]
361-
362367
# It is ABSOLUTELY CRITICAL that these go last, otherwise we can't index
363368
# the effect with "ate = t_test_results.effect[0]"
364369
individuals.loc["control", [self.treatment]] = self.control_value
@@ -424,35 +429,6 @@ def estimate_ate_calculated(self, adjustment_config: dict = None) -> tuple[float
424429

425430
return (treatment_outcome["mean"] - control_outcome["mean"]), [ci_low, ci_high]
426431

427-
def estimate_cates(self) -> tuple[float, list[float, float]]:
428-
"""Estimate the conditional average treatment effect of the treatment on the outcome. That is, the change
429-
in outcome caused by changing the treatment variable from the control value to the treatment value.
430-
431-
:return: The conditional average treatment effect and the 95% Wald confidence intervals.
432-
"""
433-
assert (
434-
self.effect_modifiers
435-
), f"Must have at least one effect modifier to compute CATE - {self.effect_modifiers}."
436-
x = pd.DataFrame()
437-
x[self.treatment] = [self.treatment_value, self.control_value]
438-
x["Intercept"] = 1 # self.intercept
439-
for k, v in self.effect_modifiers.items():
440-
self.adjustment_set.add(k)
441-
x[k] = v
442-
if hasattr(self, "square_terms"):
443-
for t in self.square_terms:
444-
x[t + "^2"] = x[t] ** 2
445-
if hasattr(self, "product_terms"):
446-
for a, b in self.product_terms:
447-
x[f"{a}*{b}"] = x[a] * x[b]
448-
449-
model = self._run_linear_regression()
450-
y = model.predict(x)
451-
treatment_outcome = y.iloc[0]
452-
control_outcome = y.iloc[1]
453-
454-
return treatment_outcome - control_outcome, None
455-
456432
def _run_linear_regression(self) -> RegressionResultsWrapper:
457433
"""Run linear regression of the treatment and adjustment set against the outcome and return the model.
458434
@@ -472,22 +448,16 @@ def _run_linear_regression(self) -> RegressionResultsWrapper:
472448
# 3. Estimate the unit difference in outcome caused by unit difference in treatment
473449
cols = [self.treatment]
474450
cols += [x for x in self.adjustment_set if x not in cols]
475-
treatment_and_adjustments_cols = reduced_df[cols + ["Intercept"]]
476-
for col in treatment_and_adjustments_cols:
477-
if str(treatment_and_adjustments_cols.dtypes[col]) == "object":
478-
treatment_and_adjustments_cols = pd.get_dummies(
479-
treatment_and_adjustments_cols, columns=[col], drop_first=True
480-
)
481451
model = smf.ols(formula=self.formula, data=self.df).fit()
482452
return model
483453

484-
def _get_confidence_intervals(self, model, treatment, alpha=0.05):
    """
    Look up the confidence interval bounds of the given treatment term(s) in a fitted model.

    :param model: fitted model exposing ``conf_int(alpha=..., cols=...)``; the result is indexed by column
        ``0`` for the lower bounds and ``1`` for the upper bounds, then by term label via ``.loc``
    :param treatment: label or list of labels of the model term(s) whose bounds are wanted
    :param alpha: significance level forwarded to ``model.conf_int``; the default of 0.05 keeps the
        previously hard-coded behaviour
    :return: ``[ci_low, ci_high]``, each being whatever ``.loc[treatment]`` yields for the respective column
    """
    confidence_intervals = model.conf_int(alpha=alpha, cols=None)
    lower_bounds, upper_bounds = confidence_intervals[0], confidence_intervals[1]
    return [lower_bounds.loc[treatment], upper_bounds.loc[treatment]]
491461

492462

493463
class InstrumentalVariableEstimator(Estimator):

causal_testing/utils/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)