Skip to content

Commit 9d32eb1

Browse files
committed
linting
1 parent f91f392 commit 9d32eb1

File tree

3 files changed

+17
-10
lines changed

3 files changed

+17
-10
lines changed

causal_testing/generation/enum_gen.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
"""This module contains the class EnumGen, which allows us to easily create generating uniform distributions from enums."""
22

3-
from scipy.stats import rv_discrete
43
from enum import Enum
4+
from scipy.stats import rv_discrete
55
import numpy as np
66

77

88
class EnumGen(rv_discrete):
9-
"""This class allows us to easily create generating uniform distributions from enums. This is helpful for generating concrete test inputs from abstract test cases."""
9+
"""This class allows us to easily create generating uniform distributions
10+
from enums. This is helpful for generating concrete test inputs from
11+
abstract test cases."""
1012

1113
def __init__(self, datatype: Enum):
12-
self.dt = dict(enumerate(datatype, 1))
13-
self.inverse_dt = {v: k for k, v in self.dt.items()}
14+
self.datatype = dict(enumerate(datatype, 1))
15+
self.inverse_dt = {v: k for k, v in self.datatype.items()}
1416

1517
def ppf(self, q):
1618
"""Percent point function (inverse of `cdf`) at q of the given RV.
@@ -23,7 +25,7 @@ def ppf(self, q):
2325
k : array_like
2426
Quantile corresponding to the lower tail probability, q.
2527
"""
26-
return np.vectorize(self.dt.get)(np.ceil(len(self.dt) * q))
28+
return np.vectorize(self.datatype.get)(np.ceil(len(self.datatype) * q))
2729

2830
def cdf(self, q):
2931
"""
@@ -37,4 +39,4 @@ def cdf(self, q):
3739
cdf : ndarray
3840
Cumulative distribution function evaluated at `x`
3941
"""
40-
return np.vectorize(self.inverse_dt.get)(q) / len(self.dt)
42+
return np.vectorize(self.inverse_dt.get)(q) / len(self.datatype)

causal_testing/testing/causal_test_outcome.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class SomeEffect(CausalTestOutcome):
2626
"""An extension of TestOutcome representing that the expected causal effect should not be zero."""
2727

2828
def apply(self, res: CausalTestResult) -> bool:
29-
if res.test_value.type == "ate" or res.test_value.type == "coefficient":
29+
if res.test_value.type in {"ate", "coefficient"}:
3030
return (0 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 0)
3131
if res.test_value.type == "risk_ratio":
3232
return (1 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 1)
@@ -38,7 +38,7 @@ class NoEffect(CausalTestOutcome):
3838

3939
def apply(self, res: CausalTestResult) -> bool:
4040
print("RESULT", res)
41-
if res.test_value.type == "ate" or res.test_value.type == "coefficient":
41+
if res.test_value.type in {"ate", "coefficient"}:
4242
return (res.ci_low() < 0 < res.ci_high()) or (abs(res.test_value.value) < 1e-10)
4343
if res.test_value.type == "risk_ratio":
4444
return (res.ci_low() < 1 < res.ci_high()) or np.isclose(res.test_value.value, 1.0, atol=1e-10)
@@ -72,7 +72,7 @@ def apply(self, res: CausalTestResult) -> bool:
7272
return False
7373
if res.test_value.type in {"ate", "coefficient"}:
7474
return res.test_value.value > 0
75-
elif res.test_value.type == "risk_ratio":
75+
if res.test_value.type == "risk_ratio":
7676
return res.test_value.value > 1
7777
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
7878

causal_testing/testing/estimators.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,8 @@ def add_modelling_assumptions(self):
529529

530530
def estimate_coefficient(self, df):
531531
"""
532-
Estimate the linear regression coefficient of the treatment on the outcome.
532+
Estimate the linear regression coefficient of the treatment on the
533+
outcome.
533534
"""
534535
# Estimate the total effect of instrument I on outcome Y = abI + c1
535536
ab = sm.OLS(df[self.outcome], df[[self.instrument]]).fit().params[self.instrument]
@@ -541,6 +542,10 @@ def estimate_coefficient(self, df):
541542
return ab / a
542543

543544
def estimate_unit_ate(self, bootstrap_size=100):
545+
"""
546+
Estimate the unit ate (i.e. coefficient) of the treatment on the
547+
outcome.
548+
"""
544549
bootstraps = sorted(
545550
[self.estimate_coefficient(self.df.sample(len(self.df), replace=True)) for _ in range(bootstrap_size)]
546551
)

0 commit comments

Comments
 (0)