Skip to content

Commit 9b0c478

Browse files
Add method to validate formulae
1 parent a2abf7b commit 9b0c478

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

causal_testing/testing/estimators.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from statsmodels.tools.sm_exceptions import PerfectSeparationError
1818

1919
from causal_testing.specification.variable import Variable
20+
from causal_testing.specification.causal_dag import CausalDAG
2021

2122
logger = logging.getLogger(__name__)
2223

@@ -122,6 +123,14 @@ def get_terms_from_formula(self):
122123
covariates = rhs_terms.remove(self.treatment)
123124
return outcome, self.treatment, covariates
124125

126+
def validate_formula(self, causal_dag: CausalDAG):
127+
outcome, treatment, covariates = causal_dag.get_terms_from_formula()
128+
proper_backdoor_graph = causal_dag.get_proper_backdoor_graph(treatments=[treatment], outcomes=[outcome])
129+
return CausalDAG.constructive_backdoor_criterion(proper_backdoor_graph=proper_backdoor_graph,
130+
treatments=[treatment], outcomes=[outcome],
131+
covariates=list(covariates))
132+
133+
125134
class LogisticRegressionEstimator(Estimator):
126135
"""A Logistic Regression Estimator is a parametric estimator which restricts the variables in the data to a linear
127136
combination of parameters and functions of the variables (note these functions need not be linear). It is designed

0 commit comments

Comments
 (0)