Skip to content

Commit 1eac357

Browse files
Fetch factors from Patsy to check types
1 parent 4736391 commit 1eac357

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

causal_testing/testing/estimators.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import statsmodels.formula.api as smf
1212
from econml.dml import CausalForestDML
1313
from patsy import dmatrix # pylint: disable = no-name-in-module
14-
14+
from patsy import ModelDesc
1515
from sklearn.ensemble import GradientBoostingRegressor
1616
from statsmodels.regression.linear_model import RegressionResultsWrapper
1717
from statsmodels.tools.sm_exceptions import PerfectSeparationError
@@ -351,7 +351,8 @@ def estimate_coefficient(self) -> float:
351351
"""
352352
model = self._run_linear_regression()
353353
newline = "\n"
354-
if self.treatment in self.df.dtypes and str(self.df.dtypes[self.treatment]) == "object":
354+
patsy_md = ModelDesc.from_formula(self.treatment)
355+
if any((self.df.dtypes[factor.name()] == 'object' for factor in patsy_md.rhs_termlist[1].factors)):
355356
design_info = dmatrix(self.formula.split("~")[1], self.df).design_info
356357
treatment = design_info.column_names[design_info.term_name_slices[self.treatment]]
357358
else:

0 commit comments

Comments
 (0)