Skip to content

Commit 0b9433a

Browse files
committed
Fixed pytest
1 parent 333a61c commit 0b9433a

File tree

4 files changed

+25
-12
lines changed

4 files changed

+25
-12
lines changed

causal_testing/generation/enum_gen.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from scipy.stats import rv_discrete
2+
3+
4+
class EnumGen(rv_discrete):
5+
def __init__(self, dt: Enum):
6+
self.dt = dict(enumerate(dt, 1))
7+
self.inverse_dt = {v: k for k, v in self.dt.items()}
8+
9+
def ppf(self, q, *args, **kwds):
10+
return np.vectorize(self.dt.get)(np.ceil(len(self.dt) * q))
11+
12+
def cdf(self, q, *args, **kwds):
13+
return np.vectorize(self.inverse_dt.get)(q) / len(Car)

causal_testing/testing/causal_test_outcome.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,26 +64,26 @@ def __str__(self):
6464
return f"ExactValue: {self.value}±{self.tolerance}"
6565

6666

67-
class Positive(CausalTestOutcome):
67+
class Positive(SomeEffect):
6868
"""An extension of TestOutcome representing that the expected causal effect should be positive."""
6969

7070
def apply(self, res: CausalTestResult) -> bool:
7171
if res.ci_valid() and not super().apply(res):
7272
return False
73-
if res.test_value.type == "ate":
73+
if res.test_value.type in {"ate", "coefficient"}:
7474
return res.test_value.value > 0
75-
if res.test_value.type == "risk_ratio":
75+
elif res.test_value.type == "risk_ratio":
7676
return res.test_value.value > 1
7777
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
7878

7979

80-
class Negative(CausalTestOutcome):
80+
class Negative(SomeEffect):
8181
"""An extension of TestOutcome representing that the expected causal effect should be negative."""
8282

8383
def apply(self, res: CausalTestResult) -> bool:
8484
if res.ci_valid() and not super().apply(res):
8585
return False
86-
if res.test_value.type == "ate":
86+
if res.test_value.type in {"ate", "coefficient"}:
8787
return res.test_value.value < 0
8888
if res.test_value.type == "risk_ratio":
8989
return res.test_value.value < 1

examples/lr91/example_max_conductances.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def effects_on_APD90(observational_data_path, treatment_var, control_val, treatm
153153
return causal_test_result.test_value.value, causal_test_result.confidence_intervals
154154

155155

156-
def plot_ates_with_cis(results_dict: dict, xs: list, save: bool = True, show: bool = False):
156+
def plot_ates_with_cis(results_dict: dict, xs: list, save: bool = False, show: bool = False):
157157
"""Plot the average treatment effects for a given treatment against a list of x-values with confidence intervals.
158158
159159
:param results_dict: A dictionary containing results for sensitivity analysis of each input parameter.

tests/testing_tests/test_estimators.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -166,15 +166,15 @@ def test_estimate_coefficient(self):
166166
Test we get the correct coefficient.
167167
"""
168168
iv_estimator = InstrumentalVariableEstimator(
169+
df=self.df,
169170
treatment="X",
170171
treatment_value=None,
171172
control_value=None,
172173
adjustment_set=set(),
173174
outcome="Y",
174175
instrument="Z",
175-
df=self.df,
176176
)
177-
self.assertEqual(iv_estimator.estimate_coefficient(), 2)
177+
self.assertEqual(iv_estimator.estimate_coefficient(self.df), 2)
178178

179179

180180
class TestLinearRegressionEstimator(unittest.TestCase):
@@ -192,21 +192,21 @@ def setUpClass(cls) -> None:
192192
def test_program_11_2(self):
193193
"""Test whether our linear regression implementation produces the same results as program 11.2 (p. 141)."""
194194
df = self.chapter_11_df
195-
linear_regression_estimator = LinearRegressionEstimator("treatments", 100, 90, set(), "outcomes", df)
195+
linear_regression_estimator = LinearRegressionEstimator("treatments", None, None, set(), "outcomes", df)
196196
model = linear_regression_estimator._run_linear_regression()
197197
ate, _ = linear_regression_estimator.estimate_unit_ate()
198198

199199
print(model.summary())
200200
self.assertEqual(round(model.params["Intercept"] + 90 * model.params["treatments"], 1), 216.9)
201201

202202
# Increasing treatments from 90 to 100 should be the same as 10 times the unit ATE
203-
self.assertEqual(round(10 * model.params["treatments"], 1), round(ate, 1))
203+
self.assertEqual(round(model.params["treatments"], 1), round(ate, 1))
204204

205205
def test_program_11_3(self):
206206
"""Test whether our linear regression implementation produces the same results as program 11.3 (p. 144)."""
207207
df = self.chapter_11_df.copy()
208208
linear_regression_estimator = LinearRegressionEstimator(
209-
"treatments", 100, 90, set(), "outcomes", df, formula="outcomes ~ treatments + np.power(treatments, 2)"
209+
"treatments", None, None, set(), "outcomes", df, formula="outcomes ~ treatments + np.power(treatments, 2)"
210210
)
211211
model = linear_regression_estimator._run_linear_regression()
212212
ate, _ = linear_regression_estimator.estimate_unit_ate()
@@ -220,7 +220,7 @@ def test_program_11_3(self):
220220
197.1,
221221
)
222222
# Increasing treatments from 90 to 100 should be the same as 10 times the unit ATE
223-
self.assertEqual(round(10 * model.params["treatments"], 3), round(ate, 3))
223+
self.assertEqual(round(model.params["treatments"], 3), round(ate, 3))
224224

225225
def test_program_15_1A(self):
226226
"""Test whether our linear regression implementation produces the same results as program 15.1 (p. 163, 184)."""

0 commit comments

Comments
 (0)