Skip to content

Commit 5e8a2a0

Browse files
committed
Causal validation techniques
1 parent 3761f75 commit 5e8a2a0

File tree

1 file changed

+55
-0
lines changed

1 file changed

+55
-0
lines changed

causal_testing/testing/validation.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
"""This module contains the CausalValidator class for performing Quantitive Bias Analysis techniques"""
2+
import math
3+
import numpy as np
4+
from scipy.stats import t
5+
from statsmodels.regression.linear_model import RegressionResultsWrapper
6+
7+
class CausalValidator:
8+
9+
def estimate_robustness(self, model: RegressionResultsWrapper, q=1, alpha=1):
10+
"""Calculate the robustness of a linear regression model. This allow
11+
the user to identify how large an unidentified confounding variable
12+
would need to be to nullify the causal relationship under test."""
13+
14+
dof = model.df_resid
15+
t_values = model.tvalues
16+
17+
fq = q * abs(t_values / math.sqrt(dof))
18+
f_crit = abs(t.ppf(alpha / 2, dof - 1)) / math.sqrt(dof - 1)
19+
fqa = fq - f_crit
20+
21+
rv = 0.5 * (np.sqrt(fqa**4 + (4 * fqa**2)) - fqa**2)
22+
23+
return rv
24+
25+
def estimate_e_value(
26+
self, risk_ratio, confidence_intervals: tuple[float, float]
27+
) -> tuple[float, tuple[float, float]]:
28+
"""Calculate the E value from a risk ratio. This allow
29+
the user to identify how large a risk an unidentified confounding
30+
variable would need to be to nullify the causal relationship
31+
under test."""
32+
33+
if risk_ratio >= 1:
34+
e = risk_ratio + math.sqrt(risk_ratio * (risk_ratio - 1))
35+
36+
lower_limit = confidence_intervals[0]
37+
if lower_limit <= 1:
38+
lower_limit = 1
39+
else:
40+
lower_limit = lower_limit + math.sqrt(lower_limit * (lower_limit - 1))
41+
42+
return (e, (lower_limit, 1))
43+
44+
else:
45+
risk_ratio_prime = 1 / risk_ratio
46+
e = risk_ratio_prime + math.sqrt(risk_ratio_prime * (risk_ratio_prime - 1))
47+
48+
upper_limit = confidence_intervals[1]
49+
if upper_limit >= 1:
50+
upper_limit = 1
51+
else:
52+
upper_limit_prime = 1 / upper_limit
53+
upper_limit = upper_limit_prime + math.sqrt(upper_limit_prime * (upper_limit_prime - 1))
54+
55+
return (e, (1, upper_limit))

0 commit comments

Comments
 (0)