Skip to content

Commit 70526c6

Browse files
Add LogisticRegressionEstimator to formula type check
1 parent 2d02e5b commit 70526c6

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

causal_testing/json_front/json_class.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from causal_testing.specification.variable import Input, Meta, Output
2323
from causal_testing.testing.causal_test_case import CausalTestCase
2424
from causal_testing.testing.causal_test_result import CausalTestResult
25-
from causal_testing.testing.estimators import Estimator, LinearRegressionEstimator
25+
from causal_testing.testing.estimators import Estimator, LinearRegressionEstimator, LogisticRegressionEstimator
2626
from causal_testing.testing.base_test_case import BaseTestCase
2727
from causal_testing.testing.causal_test_adequacy import DataAdequacy
2828

@@ -309,8 +309,10 @@ def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> Estima
309309
"""
310310
estimator_kwargs = {}
311311
if "formula" in test:
312-
if test["estimator"] != LinearRegressionEstimator:
313-
raise TypeError("Currently only LinearRegressionEstimator supports the use of formulas")
312+
if test["estimator"] != LinearRegressionEstimator or LogisticRegressionEstimator:
313+
raise TypeError(
314+
"Currently only LinearRegressionEstimator and LogisticRegressionEstimator supports the use of formulas"
315+
)
314316
estimator_kwargs["formula"] = test["formula"]
315317
estimator_kwargs["adjustment_set"] = {}
316318
else:

0 commit comments

Comments
 (0)