Skip to content

Commit 9e2118b

Browse files
committed
tests pass
1 parent 7611771 commit 9e2118b

File tree

5 files changed

+42
-15
lines changed

5 files changed

+42
-15
lines changed

causal_testing/json_front/json_class.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(self, output_path: str, output_overwrite: bool = False):
5656
self.output_path = Path(output_path)
5757
self.check_file_exists(self.output_path, output_overwrite)
5858

59-
def set_paths(self, json_path: str, dag_path: str, data_paths: str=[]):
59+
def set_paths(self, json_path: str, dag_path: str, data_paths: str = []):
6060
"""
6161
Takes a path of the directory containing all scenario specific files and creates individual paths for each file
6262
:param json_path: string path representation to .json file containing test specifications

causal_testing/testing/causal_test_outcome.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
ExactValue, Positive, Negative, SomeEffect, NoEffect"""
44

55
from abc import ABC, abstractmethod
6+
from collections.abc import Iterable
67
import numpy as np
78

89
from causal_testing.testing.causal_test_result import CausalTestResult
@@ -26,9 +27,12 @@ class SomeEffect(CausalTestOutcome):
2627
"""An extension of TestOutcome representing that the expected causal effect should not be zero."""
2728

2829
def apply(self, res: CausalTestResult) -> bool:
29-
if res.test_value.type in {"ate", "coefficient"}:
30-
return any([0 < ci_low < ci_high or ci_low < ci_high < 0 for ci_low, ci_high in zip(res.ci_low(), res.ci_high())])
31-
# return (0 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 0)
30+
if res.test_value.type == "ate":
31+
return (0 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 0)
32+
if res.test_value.type == "coefficient":
33+
ci_low = res.ci_low() if isinstance(res.ci_low(), Iterable) else [res.ci_low()]
34+
ci_high = res.ci_high() if isinstance(res.ci_high(), Iterable) else [res.ci_high()]
35+
return any([0 < ci_low < ci_high or ci_low < ci_high < 0 for ci_low, ci_high in zip(ci_low, ci_high)])
3236
if res.test_value.type == "risk_ratio":
3337
return (1 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 1)
3438
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
@@ -38,10 +42,15 @@ class NoEffect(CausalTestOutcome):
3842
"""An extension of TestOutcome representing that the expected causal effect should be zero."""
3943

4044
def apply(self, res: CausalTestResult, threshold: float = 1e-10) -> bool:
41-
print("RESULT", res)
42-
if res.test_value.type in {"ate", "coefficient"}:
43-
return all([ci_low < 0< ci_high for ci_low, ci_high in zip(res.ci_low(), res.ci_high())]) or all([abs(v) < 1e-10 for v in res.test_value.value])
44-
# return (res.ci_low() < 0 < res.ci_high()) or (abs(res.test_value.value) < 1e-10)
45+
if res.test_value.type == "ate":
46+
return (res.ci_low() < 0 < res.ci_high()) or (abs(res.test_value.value) < 1e-10)
47+
if res.test_value.type == "coefficient":
48+
ci_low = res.ci_low() if isinstance(res.ci_low(), Iterable) else [res.ci_low()]
49+
ci_high = res.ci_high() if isinstance(res.ci_high(), Iterable) else [res.ci_high()]
50+
value = res.test_value.value if isinstance(res.ci_high(), Iterable) else [res.test_value.value]
51+
return all([ci_low < 0 < ci_high for ci_low, ci_high in zip(ci_low, ci_high)]) or all(
52+
[abs(v) < 1e-10 for v in value]
53+
)
4554
if res.test_value.type == "risk_ratio":
4655
return (res.ci_low() < 1 < res.ci_high()) or np.isclose(res.test_value.value, 1.0, atol=1e-10)
4756
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")

causal_testing/testing/causal_test_result.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,19 +44,26 @@ def __init__(
4444

4545
def __str__(self):
4646
def push(s, inc=" "):
47-
return inc + str(s).replace("\n", "\n"+inc)
47+
return inc + str(s).replace("\n", "\n" + inc)
48+
49+
result_str = " " + str(self.test_value.value)
50+
if "\n" in result_str:
51+
result_str = "\n" + push(self.test_value.value)
4852
base_str = (
4953
f"Causal Test Result\n==============\n"
5054
f"Treatment: {self.estimator.treatment}\n"
5155
f"Control value: {self.estimator.control_value}\n"
5256
f"Treatment value: {self.estimator.treatment_value}\n"
5357
f"Outcome: {self.estimator.outcome}\n"
5458
f"Adjustment set: {self.adjustment_set}\n"
55-
f"{self.test_value.type}:\n{push(self.test_value.value)}\n"
59+
f"{self.test_value.type}:{result_str}\n"
5660
)
5761
confidence_str = ""
5862
if self.confidence_intervals:
59-
confidence_str += f"Confidence intervals:\n{push(pd.DataFrame(self.confidence_intervals).transpose().to_string(header=False))}\n"
63+
ci_str = " " + str(self.confidence_intervals)
64+
if "\n" in ci_str:
65+
ci_str = " " + push(pd.DataFrame(self.confidence_intervals).transpose().to_string(header=False))
66+
confidence_str += f"Confidence intervals:{ci_str}\n"
6067
return base_str + confidence_str
6168

6269
def to_dict(self):

causal_testing/testing/estimators.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -339,11 +339,22 @@ def estimate_unit_ate(self) -> float:
339339
print(model.conf_int())
340340
treatment = [self.treatment]
341341
if str(self.df.dtypes[self.treatment]) == "object":
342-
reference = min(self.df[self.treatment])
343-
treatment = [x.replace("[", "[T.") for x in dmatrix(f"{self.treatment}-1", self.df.query(f"{self.treatment} != '{reference}'"), return_type="dataframe").columns]
344-
assert set(treatment).issubset(model.params.index.tolist()), f"{treatment} not in\n{' '+str(model.params.index).replace(newline, newline+' ')}"
342+
reference = min(self.df[self.treatment])
343+
treatment = [
344+
x.replace("[", "[T.")
345+
for x in dmatrix(
346+
f"{self.treatment}-1", self.df.query(f"{self.treatment} != '{reference}'"), return_type="dataframe"
347+
).columns
348+
]
349+
assert set(treatment).issubset(
350+
model.params.index.tolist()
351+
), f"{treatment} not in\n{' '+str(model.params.index).replace(newline, newline+' ')}"
345352
unit_effect = model.params[treatment] # Unit effect is the coefficient of the treatment
346353
[ci_low, ci_high] = self._get_confidence_intervals(model, treatment)
354+
if str(self.df.dtypes[self.treatment]) != "object":
355+
unit_effect = unit_effect[0]
356+
ci_low = ci_low[0]
357+
ci_high = ci_high[0]
347358
return unit_effect, [ci_low, ci_high]
348359

349360
def estimate_ate(self) -> tuple[float, list[float, float], float]:
@@ -365,7 +376,6 @@ def estimate_ate(self) -> tuple[float, list[float, float], float]:
365376

366377
# Perform a t-test to compare the predicted outcome of the control and treated individual (ATE)
367378
t_test_results = model.t_test(individuals.loc["treated"] - individuals.loc["control"])
368-
print("t_test_results", t_test_results.effect)
369379
ate = t_test_results.effect[0]
370380
confidence_intervals = list(t_test_results.conf_int().flatten())
371381
return ate, confidence_intervals

tests/testing_tests/test_estimators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ def test_program_11_2(self):
216216
self.assertEqual(round(model.params["Intercept"] + 90 * model.params["treatments"], 1), 216.9)
217217

218218
# Increasing treatments from 90 to 100 should be the same as 10 times the unit ATE
219+
print("ATE", ate)
219220
self.assertEqual(round(model.params["treatments"], 1), round(ate, 1))
220221

221222
def test_program_11_3(self):

0 commit comments

Comments
 (0)