Skip to content

Commit e2e5e99

Browse files
committed
pylint
1 parent 9e2118b commit e2e5e99

File tree

6 files changed

+71
-46
lines changed

6 files changed

+71
-46
lines changed

causal_testing/json_front/json_class.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,15 @@ def __init__(self, output_path: str, output_overwrite: bool = False):
5656
self.output_path = Path(output_path)
5757
self.check_file_exists(self.output_path, output_overwrite)
5858

59-
def set_paths(self, json_path: str, dag_path: str, data_paths: str = []):
59+
def set_paths(self, json_path: str, dag_path: str, data_paths: list[str] = None):
6060
"""
6161
Takes a path of the directory containing all scenario specific files and creates individual paths for each file
6262
:param json_path: string path representation to .json file containing test specifications
6363
:param dag_path: string path representation to the .dot file containing the Causal DAG
6464
:param data_paths: string path representation to the data files
6565
"""
66+
if data_paths is None:
67+
data_paths = []
6668
self.input_paths = JsonClassPaths(json_path=json_path, dag_path=dag_path, data_paths=data_paths)
6769

6870
def setup(self, scenario: Scenario):

causal_testing/testing/causal_test_outcome.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def apply(self, res: CausalTestResult) -> bool:
3232
if res.test_value.type == "coefficient":
3333
ci_low = res.ci_low() if isinstance(res.ci_low(), Iterable) else [res.ci_low()]
3434
ci_high = res.ci_high() if isinstance(res.ci_high(), Iterable) else [res.ci_high()]
35-
return any([0 < ci_low < ci_high or ci_low < ci_high < 0 for ci_low, ci_high in zip(ci_low, ci_high)])
35+
return any(0 < ci_low < ci_high or ci_low < ci_high < 0 for ci_low, ci_high in zip(ci_low, ci_high))
3636
if res.test_value.type == "risk_ratio":
3737
return (1 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 1)
3838
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
@@ -41,18 +41,18 @@ def apply(self, res: CausalTestResult) -> bool:
4141
class NoEffect(CausalTestOutcome):
4242
"""An extension of TestOutcome representing that the expected causal effect should be zero."""
4343

44-
def apply(self, res: CausalTestResult, threshold: float = 1e-10) -> bool:
44+
def apply(self, res: CausalTestResult, atol: float = 1e-10) -> bool:
4545
if res.test_value.type == "ate":
46-
return (res.ci_low() < 0 < res.ci_high()) or (abs(res.test_value.value) < 1e-10)
46+
return (res.ci_low() < 0 < res.ci_high()) or (abs(res.test_value.value) < atol)
4747
if res.test_value.type == "coefficient":
4848
ci_low = res.ci_low() if isinstance(res.ci_low(), Iterable) else [res.ci_low()]
4949
ci_high = res.ci_high() if isinstance(res.ci_high(), Iterable) else [res.ci_high()]
5050
value = res.test_value.value if isinstance(res.ci_high(), Iterable) else [res.test_value.value]
51-
return all([ci_low < 0 < ci_high for ci_low, ci_high in zip(ci_low, ci_high)]) or all(
52-
[abs(v) < 1e-10 for v in value]
51+
return all(ci_low < 0 < ci_high for ci_low, ci_high in zip(ci_low, ci_high)) or all(
52+
abs(v) < 1e-10 for v in value
5353
)
5454
if res.test_value.type == "risk_ratio":
55-
return (res.ci_low() < 1 < res.ci_high()) or np.isclose(res.test_value.value, 1.0, atol=1e-10)
55+
return (res.ci_low() < 1 < res.ci_high()) or np.isclose(res.test_value.value, 1.0, atol=atol)
5656
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
5757

5858

@@ -85,7 +85,8 @@ def apply(self, res: CausalTestResult) -> bool:
8585
return res.test_value.value > 0
8686
if res.test_value.type == "risk_ratio":
8787
return res.test_value.value > 1
88-
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
88+
# Dead code?
89+
# raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
8990

9091

9192
class Negative(SomeEffect):
@@ -98,4 +99,5 @@ def apply(self, res: CausalTestResult) -> bool:
9899
return res.test_value.value < 0
99100
if res.test_value.type == "risk_ratio":
100101
return res.test_value.value < 1
101-
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
102+
# Dead code?
103+
# raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")

causal_testing/testing/estimators.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -434,35 +434,6 @@ def estimate_ate_calculated(self, adjustment_config: dict = None) -> tuple[float
434434

435435
return (treatment_outcome["mean"] - control_outcome["mean"]), [ci_low, ci_high]
436436

437-
def estimate_cates(self) -> tuple[float, list[float, float]]:
438-
"""Estimate the conditional average treatment effect of the treatment on the outcome. That is, the change
439-
in outcome caused by changing the treatment variable from the control value to the treatment value.
440-
441-
:return: The conditional average treatment effect and the 95% Wald confidence intervals.
442-
"""
443-
assert (
444-
self.effect_modifiers
445-
), f"Must have at least one effect modifier to compute CATE - {self.effect_modifiers}."
446-
x = pd.DataFrame()
447-
x[self.treatment] = [self.treatment_value, self.control_value]
448-
x["Intercept"] = 1 # self.intercept
449-
for k, v in self.effect_modifiers.items():
450-
self.adjustment_set.add(k)
451-
x[k] = v
452-
if hasattr(self, "square_terms"):
453-
for t in self.square_terms:
454-
x[t + "^2"] = x[t] ** 2
455-
if hasattr(self, "product_terms"):
456-
for a, b in self.product_terms:
457-
x[f"{a}*{b}"] = x[a] * x[b]
458-
459-
model = self._run_linear_regression()
460-
y = model.predict(x)
461-
treatment_outcome = y.iloc[0]
462-
control_outcome = y.iloc[1]
463-
464-
return treatment_outcome - control_outcome, None
465-
466437
def _run_linear_regression(self) -> RegressionResultsWrapper:
467438
"""Run linear regression of the treatment and adjustment set against the outcome and return the model.
468439
@@ -482,7 +453,6 @@ def _run_linear_regression(self) -> RegressionResultsWrapper:
482453
# 3. Estimate the unit difference in outcome caused by unit difference in treatment
483454
cols = [self.treatment]
484455
cols += [x for x in self.adjustment_set if x not in cols]
485-
treatment_and_adjustments_cols = reduced_df[cols + ["Intercept"]]
486456
model = smf.ols(formula=self.formula, data=self.df).fit()
487457
return model
488458

tests/json_front_tests/test_json_class.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,13 @@ def setUp(self) -> None:
4343
self.json_class.set_paths(self.json_path, self.dag_path, self.data_path)
4444
self.json_class.setup(self.scenario)
4545

46+
def test_setting_no_path(self):
47+
json_class = JsonUtility("temp_out.txt", True)
48+
json_class.set_paths(self.json_path, self.dag_path, None)
49+
self.assertEqual(json_class.input_paths.data_paths, []) # Needs to be list of Paths
50+
51+
52+
4653
def test_setting_paths(self):
4754
self.assertEqual(self.json_class.input_paths.json_path, Path(self.json_path))
4855
self.assertEqual(self.json_class.input_paths.dag_path, Path(self.dag_path))

tests/testing_tests/test_causal_test_outcome.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def test_empty_adjustment_set(self):
6464
),
6565
)
6666

67-
def test_Positive_pass(self):
67+
def test_Positive_ate_pass(self):
6868
test_value = TestValue(type="ate", value=5.05)
6969
ctr = CausalTestResult(
7070
estimator=self.estimator,
@@ -75,6 +75,17 @@ def test_Positive_pass(self):
7575
ev = Positive()
7676
self.assertTrue(ev.apply(ctr))
7777

78+
def test_Positive_risk_ratio_pass(self):
79+
test_value = TestValue(type="risk_ratio", value=2)
80+
ctr = CausalTestResult(
81+
estimator=self.estimator,
82+
test_value=test_value,
83+
confidence_intervals=None,
84+
effect_modifier_configuration=None,
85+
)
86+
ev = Positive()
87+
self.assertTrue(ev.apply(ctr))
88+
7889
def test_Positive_fail(self):
7990
test_value = TestValue(type="ate", value=0)
8091
ctr = CausalTestResult(
@@ -97,7 +108,7 @@ def test_Positive_fail_ci(self):
97108
ev = Positive()
98109
self.assertFalse(ev.apply(ctr))
99110

100-
def test_Negative_pass(self):
111+
def test_Negative_ate_pass(self):
101112
test_value = TestValue(type="ate", value=-5.05)
102113
ctr = CausalTestResult(
103114
estimator=self.estimator,
@@ -108,6 +119,17 @@ def test_Negative_pass(self):
108119
ev = Negative()
109120
self.assertTrue(ev.apply(ctr))
110121

122+
def test_Negative_risk_ratio_pass(self):
123+
test_value = TestValue(type="risk_ratio", value=0.2)
124+
ctr = CausalTestResult(
125+
estimator=self.estimator,
126+
test_value=test_value,
127+
confidence_intervals=None,
128+
effect_modifier_configuration=None,
129+
)
130+
ev = Negative()
131+
self.assertTrue(ev.apply(ctr))
132+
111133
def test_Negative_fail(self):
112134
test_value = TestValue(type="ate", value=0)
113135
ctr = CausalTestResult(
@@ -163,17 +185,33 @@ def test_exactValue_fail(self):
163185
ev = ExactValue(5, 0.1)
164186
self.assertFalse(ev.apply(ctr))
165187

166-
def test_someEffect_invalid(self):
188+
def test_invalid(self):
167189
test_value = TestValue(type="invalid", value=5.05)
168190
ctr = CausalTestResult(
169191
estimator=self.estimator,
170192
test_value=test_value,
171193
confidence_intervals=[4.8, 6.7],
172194
effect_modifier_configuration=None,
173195
)
174-
ev = SomeEffect()
175196
with self.assertRaises(ValueError):
176-
ev.apply(ctr)
197+
SomeEffect().apply(ctr)
198+
with self.assertRaises(ValueError):
199+
NoEffect().apply(ctr)
200+
with self.assertRaises(ValueError):
201+
Positive().apply(ctr)
202+
with self.assertRaises(ValueError):
203+
Negative().apply(ctr)
204+
205+
def test_someEffect_pass_coefficient(self):
206+
test_value = TestValue(type="coefficient", value=5.05)
207+
ctr = CausalTestResult(
208+
estimator=self.estimator,
209+
test_value=test_value,
210+
confidence_intervals=[4.8, 6.7],
211+
effect_modifier_configuration=None,
212+
)
213+
self.assertTrue(SomeEffect().apply(ctr))
214+
self.assertFalse(NoEffect().apply(ctr))
177215

178216
def test_someEffect_pass_ate(self):
179217
test_value = TestValue(type="ate", value=5.05)

tests/testing_tests/test_estimators.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,14 @@ def setUpClass(cls) -> None:
9393
]
9494
)
9595

96+
# Yes, this probably shouldn't be in here, but it uses the scarf data so it makes more sense to put it
97+
# here than duplicating the scarf data for a single test
98+
def test_linear_regression_categorical_ate(self):
99+
df = self.scarf_df.copy()
100+
logistic_regression_estimator = LinearRegressionEstimator("color", None, None, set(), "completed", df)
101+
ate, confidence = logistic_regression_estimator.estimate_unit_ate()
102+
self.assertTrue(all([ci_low < 0 < ci_high for ci_low, ci_high in zip(confidence[0], confidence[1])]))
103+
96104
def test_ate(self):
97105
df = self.scarf_df.copy()
98106
logistic_regression_estimator = LogisticRegressionEstimator("length_in", 65, 55, set(), "completed", df)
@@ -212,11 +220,9 @@ def test_program_11_2(self):
212220
model = linear_regression_estimator._run_linear_regression()
213221
ate, _ = linear_regression_estimator.estimate_unit_ate()
214222

215-
print(model.summary())
216223
self.assertEqual(round(model.params["Intercept"] + 90 * model.params["treatments"], 1), 216.9)
217224

218225
# Increasing treatments from 90 to 100 should be the same as 10 times the unit ATE
219-
print("ATE", ate)
220226
self.assertEqual(round(model.params["treatments"], 1), round(ate, 1))
221227

222228
def test_program_11_3(self):

0 commit comments

Comments
 (0)