Skip to content

Commit 8f32552

Browse files
committed
Extra coverage
1 parent 147a22b commit 8f32552

File tree

6 files changed

+193
-34
lines changed

6 files changed

+193
-34
lines changed

causal_testing/specification/variable.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,6 @@ def typestring(self) -> str:
201201
"""
202202
return type(self).__name__
203203

204-
@abstractmethod
205204
def copy(self, name: str = None) -> Variable:
206205
"""Return a new instance of the Variable with the given name, or with
207206
the original name if no name is supplied.
@@ -211,26 +210,18 @@ def copy(self, name: str = None) -> Variable:
211210
:rtype: Variable
212211
213212
"""
214-
raise NotImplementedError("Method `copy` must be instantiated.")
213+
if name:
214+
return self.__class__(name, self.datatype, self.distribution)
215+
return self.__class__(self.name, self.datatype, self.distribution)
215216

216217

217218
class Input(Variable):
218219
"""An extension of the Variable class representing inputs."""
219220

220-
def copy(self, name=None) -> Input:
221-
if name:
222-
return Input(name, self.datatype, self.distribution)
223-
return Input(self.name, self.datatype, self.distribution)
224-
225221

226222
class Output(Variable):
227223
"""An extension of the Variable class representing outputs."""
228224

229-
def copy(self, name=None) -> Output:
230-
if name:
231-
return Output(name, self.datatype, self.distribution)
232-
return Output(self.name, self.datatype, self.distribution)
233-
234225

235226
class Meta(Variable):
236227
"""An extension of the Variable class representing metavariables. These are variables which are relevant to the
@@ -250,8 +241,3 @@ class Meta(Variable):
250241
def __init__(self, name: str, datatype: T, populate: Callable[[DataFrame], DataFrame]):
251242
super().__init__(name, datatype)
252243
self.populate = populate
253-
254-
def copy(self, name=None) -> Meta:
255-
if name:
256-
return Meta(name, self.datatype, self.distribution)
257-
return Meta(self.name, self.datatype, self.distribution)

causal_testing/testing/estimators.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def __init__(
5757
else:
5858
raise ValueError(f"Unsupported type for effect_modifiers {effect_modifiers}. Expected iterable")
5959
self.modelling_assumptions = []
60+
self.add_modelling_assumptions()
6061
logger.debug("Effect Modifiers: %s", self.effect_modifiers)
6162

6263
@abstractmethod
@@ -403,18 +404,23 @@ def estimate_ate(self) -> tuple[float, list[float, float], float]:
403404
confidence_intervals = list(t_test_results.conf_int().flatten())
404405
return ate, confidence_intervals
405406

406-
def estimate_control_treatment(self) -> tuple[pd.Series, pd.Series]:
407+
def estimate_control_treatment(self, adjustment_config: dict = None) -> tuple[pd.Series, pd.Series]:
407408
"""Estimate the outcomes under control and treatment.
408409
409410
:return: The estimated outcome under control and treatment in the form
410411
(control_outcome, treatment_outcome).
411412
"""
413+
if adjustment_config is None:
414+
adjustment_config = dict()
415+
412416
model = self._run_linear_regression()
413417
self.model = model
414418

415419
x = pd.DataFrame()
416420
x[self.treatment[0]] = [self.treatment_values, self.control_values]
417421
x["Intercept"] = self.intercept
422+
for k, v in adjustment_config.items():
423+
x[k] = v
418424
for k, v in self.effect_modifiers.items():
419425
x[k] = v
420426
for t in self.square_terms:
@@ -443,16 +449,15 @@ def estimate_risk_ratio(self) -> tuple[float, list[float, float]]:
443449

444450
return (treatment_outcome["mean"] / control_outcome["mean"]), [ci_low, ci_high]
445451

446-
def estimate_ate_calculated(self) -> tuple[float, list[float, float]]:
452+
def estimate_ate_calculated(self, adjustment_config: dict = None) -> tuple[float, list[float, float]]:
447453
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
448454
by changing the treatment variable from the control value to the treatment value. Here, we actually
449455
calculate the expected outcomes under control and treatment and divide one by the other. This
450456
allows for custom terms to be put in such as squares, inverses, products, etc.
451457
452458
:return: The average treatment effect and the 95% Wald confidence intervals.
453459
"""
454-
control_outcome, treatment_outcome = self.estimate_control_treatment()
455-
assert False
460+
control_outcome, treatment_outcome = self.estimate_control_treatment(adjustment_config=adjustment_config)
456461
ci_low = treatment_outcome["mean_ci_lower"] - control_outcome["mean_ci_upper"]
457462
ci_high = treatment_outcome["mean_ci_upper"] - control_outcome["mean_ci_lower"]
458463

tests/data_collection_tests/test_observational_data_collector.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,43 +5,54 @@
55
from causal_testing.specification.causal_specification import Scenario
66
from causal_testing.specification.variable import Input, Output, Meta
77
from scipy.stats import uniform, rv_discrete
8+
from enum import Enum
9+
import random
810
from tests.test_helpers import create_temp_dir_if_non_existent, remove_temp_dir_if_existent
911

1012

1113
class TestObservationalDataCollector(unittest.TestCase):
1214
def setUp(self) -> None:
15+
class Color(Enum):
16+
RED = "RED"
17+
GREEN = "GREEN"
18+
BLUE = "BLUE"
19+
1320
temp_dir_path = create_temp_dir_if_non_existent()
1421
self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
1522
self.observational_df_path = os.path.join(temp_dir_path, "observational_data.csv")
1623
# Y = 3*X1 + X2*X3 + 10
17-
self.observational_df = pd.DataFrame({"X1": [1, 2, 3, 4], "X2": [5, 6, 7, 8], "X3": [10, 20, 30, 40]})
18-
self.observational_df["Y"] = self.observational_df.apply(
24+
self.observational_df = pd.DataFrame(
25+
{"X1": [1, 2, 3, 4], "X2": [5, 6, 7, 8], "X3": [10, 20, 30, 40], "Y2": ["RED", "GREEN", "BLUE", "BLUE"]}
26+
)
27+
self.observational_df["Y1"] = self.observational_df.apply(
1928
lambda row: (3 * row.X1) + (row.X2 * row.X3) + 10, axis=1
2029
)
2130
self.observational_df.to_csv(self.observational_df_path)
31+
self.observational_df["Y2"] = [Color[x] for x in self.observational_df["Y2"]]
2232
self.X1 = Input("X1", int, uniform(1, 4))
2333
self.X2 = Input("X2", int, rv_discrete(values=([7], [1])))
2434
self.X3 = Input("X3", int, uniform(10, 40))
2535
self.X4 = Input("X4", int, rv_discrete(values=([10], [1])))
26-
self.Y = Output("Y", int)
36+
self.Y1 = Output("Y1", int)
37+
self.Y2 = Output("Y2", Color)
2738

2839
def test_not_all_variables_in_data(self):
2940
scenario = Scenario({self.X1, self.X2, self.X3, self.X4})
3041
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df_path)
3142
self.assertRaises(IndexError, observational_data_collector.collect_data)
3243

3344
def test_all_variables_in_data(self):
34-
scenario = Scenario({self.X1, self.X2, self.X3, self.Y})
45+
scenario = Scenario({self.X1, self.X2, self.X3, self.Y1, self.Y2})
3546
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df_path)
3647
df = observational_data_collector.collect_data(index_col=0)
37-
assert df.equals(self.observational_df), f"{df}\nwas not equal to\n{self.observational_df}"
48+
assert df.equals(self.observational_df), f"\n{df}\nwas not equal to\n{self.observational_df}"
3849

3950
def test_data_constraints(self):
40-
scenario = Scenario({self.X1, self.X2, self.X3, self.Y}, {self.X1.z3 > 2})
51+
scenario = Scenario({self.X1, self.X2, self.X3, self.Y1, self.Y2}, {self.X1.z3 > 2})
4152
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df_path)
4253
df = observational_data_collector.collect_data(index_col=0)
4354
expected = self.observational_df.loc[[2, 3]]
44-
assert df.equals(expected), f"{df}\nwas not equal to\n{expected}"
55+
assert df.equals(expected), f"\n{df}\nwas not equal to\n{expected}"
4556

4657
def test_meta_population(self):
4758
def populate_m(data):

tests/generation_tests/test_abstract_test_case.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,36 @@
11
import unittest
22
import os
33
import pandas as pd
4+
import numpy as np
45
from causal_testing.generation.abstract_causal_test_case import AbstractCausalTestCase
56
from causal_testing.specification.causal_specification import Scenario
67
from causal_testing.specification.variable import Input, Output
78
from scipy.stats import uniform, rv_discrete
89
from tests.test_helpers import create_temp_dir_if_non_existent, remove_temp_dir_if_existent
910
from causal_testing.testing.causal_test_outcome import Positive
11+
from z3 import And
12+
from enum import Enum
13+
14+
15+
class Car(Enum):
16+
isetta = "vehicle.bmw.isetta"
17+
mkz2017 = "vehicle.lincoln.mkz2017"
18+
19+
def __gt__(self, other):
20+
if self.__class__ is other.__class__:
21+
return self.value > other.value
22+
return NotImplemented
23+
24+
25+
class CarGen(rv_discrete):
26+
cars = dict(enumerate(Car, 1))
27+
inverse_cars = {v: k for k, v in cars.items()}
28+
29+
def ppf(self, q, *args, **kwds):
30+
return np.vectorize(self.cars.get)(np.ceil(len(self.cars) * q))
31+
32+
def cdf(self, q, *args, **kwds):
33+
return np.vectorize(self.inverse_cars.get)(q) / len(Car)
1034

1135

1236
class TestAbstractTestCase(unittest.TestCase):
@@ -28,6 +52,8 @@ def setUp(self) -> None:
2852
self.X2 = Input("X2", int, rv_discrete(values=([7], [1])))
2953
self.X3 = Input("X3", float, uniform(10, 40))
3054
self.X4 = Input("X4", int, rv_discrete(values=([10], [1])))
55+
self.X5 = Input("X5", bool, rv_discrete(values=(range(2), [0.5, 0.5])))
56+
self.Car = Input("Car", Car, CarGen())
3157
self.Y = Output("Y", int)
3258

3359
def test_generate_concrete_test_cases(self):
@@ -44,6 +70,38 @@ def test_generate_concrete_test_cases(self):
4470
assert len(concrete_tests) == 2, "Expected 2 concrete tests"
4571
assert len(runs) == 2, "Expected 2 runs"
4672

73+
def test_generate_boolean_concrete_test_cases(self):
74+
scenario = Scenario({self.X1, self.X2, self.X3, self.X5})
75+
scenario.setup_treatment_variables()
76+
abstract = AbstractCausalTestCase(
77+
scenario=scenario,
78+
intervention_constraints={
79+
And(scenario.treatment_variables[self.X5.name].z3 == True, scenario.variables[self.X5.name].z3 == False)
80+
},
81+
treatment_variable=self.X5,
82+
expected_causal_effect={self.Y: Positive()},
83+
effect_modifiers=None,
84+
)
85+
concrete_tests, runs = abstract.generate_concrete_tests(2)
86+
assert len(concrete_tests) == 1, "Expected 1 concrete test"
87+
assert len(runs) == 1, "Expected 1 run"
88+
89+
def test_generate_enum_concrete_test_cases(self):
90+
scenario = Scenario({self.Car})
91+
scenario.setup_treatment_variables()
92+
abstract = AbstractCausalTestCase(
93+
scenario=scenario,
94+
intervention_constraints={
95+
scenario.treatment_variables[self.Car.name].z3 != scenario.variables[self.Car.name].z3
96+
},
97+
treatment_variable=self.Car,
98+
expected_causal_effect={self.Y: Positive()},
99+
effect_modifiers=None,
100+
)
101+
concrete_tests, runs = abstract.generate_concrete_tests(2)
102+
assert len(concrete_tests) == 2, "Expected 2 concrete tests"
103+
assert len(runs) == 2, "Expected 2 runs"
104+
47105
def test_str(self):
48106
scenario = Scenario({self.X1, self.X2, self.X3, self.X4})
49107
scenario.setup_treatment_variables()

tests/specification_tests/test_variable.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import unittest
22
from enum import Enum
33
import z3
4+
from scipy.stats import norm, kstest
45

56
from causal_testing.specification.variable import z3_types, Variable, Input
67

@@ -35,6 +36,44 @@ class Color(Enum):
3536
# z3_types(Color)("color") != z3_types(Color)("color")
3637
self.assertEqual(list(map(str, expected_values)), list(map(str, z3_color_values)))
3738

39+
def test_cast_z3_bool(self):
40+
bip = Input("bip", bool)
41+
s = z3.Solver()
42+
t = z3.Bool("t")
43+
f = z3.Bool("f")
44+
s.add(t)
45+
s.add(z3.Not(f))
46+
s.check()
47+
self.assertEqual(bip.cast(s.model()[t]), True)
48+
self.assertEqual(bip.cast(s.model()[f]), False)
49+
50+
def test_cast_z3_string(self):
51+
ip = Input("bip", str)
52+
s = z3.Solver()
53+
t = z3.String("t")
54+
s.add(t == "hello")
55+
s.check()
56+
self.assertEqual(ip.cast(s.model()[t]), "hello")
57+
58+
def test_sample_flakey(self):
59+
ip = Input("ip", float, norm)
60+
self.assertGreater(kstest(ip.sample(10), norm.cdf).pvalue, 0.95)
61+
62+
def test_cast_enum(self):
63+
class Color(Enum):
64+
"""
65+
Example enum class color.
66+
"""
67+
68+
RED = "RED"
69+
GREEN = "GREEN"
70+
BLUE = "BLUE"
71+
72+
color = Input("color", Color)
73+
74+
dtype, colours = z3.EnumSort("color", ("RED", "GREEN", "BLUE"))
75+
self.assertEqual(color.cast(colours[0]), Color.RED)
76+
3877
def test_z3_value_enum(self):
3978
class Color(Enum):
4079
"""
@@ -89,16 +128,18 @@ class Err:
89128

90129
def test_typestring(self):
91130
class Var(Variable):
92-
"""
93-
The simplest class which will elicit the correct error.
94-
"""
95-
96-
def copy(self, name: str = None):
97-
pass
131+
pass
98132

99133
var = Var("v", int)
100134
self.assertEqual(var.typestring(), "Var")
101135

136+
def test_copy(self):
137+
ip = Input("ip", float, norm)
138+
self.assertNotEqual(ip.copy(), ip)
139+
self.assertEqual(ip.copy().name, ip.name)
140+
self.assertEqual(ip.copy().datatype, ip.datatype)
141+
self.assertEqual(ip.copy().distribution, ip.distribution)
142+
102143

103144
class TestZ3Methods(unittest.TestCase):
104145

tests/testing_tests/test_estimators.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,64 @@ def test_program_15_no_interaction(self):
209209
self.assertEqual(round(ate, 1), 3.5)
210210
self.assertEqual([round(ci_low, 1), round(ci_high, 1)], [2.6, 4.3])
211211

212+
def test_program_15_no_interaction_ate(self):
213+
"""Test whether our linear regression implementation produces the same results as program 15.1 (p. 163, 184)
214+
without product parameter."""
215+
df = self.nhefs_df
216+
covariates = {
217+
"sex",
218+
"race",
219+
"age",
220+
"edu_2",
221+
"edu_3",
222+
"edu_4",
223+
"edu_5",
224+
"exercise_1",
225+
"exercise_2",
226+
"active_1",
227+
"active_2",
228+
"wt71",
229+
"smokeintensity",
230+
"smokeyrs",
231+
}
232+
linear_regression_estimator = LinearRegressionEstimator(("qsmk",), 1, 0, covariates, ("wt82_71",), df)
233+
terms_to_square = ["age", "wt71", "smokeintensity", "smokeyrs"]
234+
for term_to_square in terms_to_square:
235+
linear_regression_estimator.add_squared_term_to_df(term_to_square)
236+
ate, [ci_low, ci_high] = linear_regression_estimator.estimate_ate()
237+
self.assertEqual(round(ate, 1), 3.5)
238+
self.assertEqual([round(ci_low, 1), round(ci_high, 1)], [2.6, 4.3])
239+
240+
def test_program_15_no_interaction_ate_calculated(self):
241+
"""Test whether our linear regression implementation produces the same results as program 15.1 (p. 163, 184)
242+
without product parameter."""
243+
df = self.nhefs_df
244+
covariates = {
245+
"sex",
246+
"race",
247+
"age",
248+
"edu_2",
249+
"edu_3",
250+
"edu_4",
251+
"edu_5",
252+
"exercise_1",
253+
"exercise_2",
254+
"active_1",
255+
"active_2",
256+
"wt71",
257+
"smokeintensity",
258+
"smokeyrs",
259+
}
260+
linear_regression_estimator = LinearRegressionEstimator(("qsmk",), 1, 0, covariates, ("wt82_71",), df)
261+
terms_to_square = ["age", "wt71", "smokeintensity", "smokeyrs"]
262+
for term_to_square in terms_to_square:
263+
linear_regression_estimator.add_squared_term_to_df(term_to_square)
264+
ate, [ci_low, ci_high] = linear_regression_estimator.estimate_ate_calculated(
265+
{k: self.nhefs_df.mean()[k] for k in covariates}
266+
)
267+
self.assertEqual(round(ate, 1), 3.5)
268+
self.assertEqual([round(ci_low, 1), round(ci_high, 1)], [1.9, 5])
269+
212270

213271
class TestCausalForestEstimator(unittest.TestCase):
214272
"""Test the linear regression estimator against the programming exercises in Section 2 of Hernán and Robins [1].

0 commit comments

Comments
 (0)