File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change 11
11
import statsmodels .formula .api as smf
12
12
from econml .dml import CausalForestDML
13
13
from patsy import dmatrix # pylint: disable = no-name-in-module
14
-
14
+ from patsy import ModelDesc
15
15
from sklearn .ensemble import GradientBoostingRegressor
16
16
from statsmodels .regression .linear_model import RegressionResultsWrapper
17
17
from statsmodels .tools .sm_exceptions import PerfectSeparationError
@@ -351,7 +351,8 @@ def estimate_coefficient(self) -> float:
351
351
"""
352
352
model = self ._run_linear_regression ()
353
353
newline = "\n "
354
- if self .treatment in self .df .dtypes and str (self .df .dtypes [self .treatment ]) == "object" :
354
+ patsy_md = ModelDesc .from_formula (self .treatment )
355
+ if any ((self .df .dtypes [factor .name ()] == 'object' for factor in patsy_md .rhs_termlist [1 ].factors )):
355
356
design_info = dmatrix (self .formula .split ("~" )[1 ], self .df ).design_info
356
357
treatment = design_info .column_names [design_info .term_name_slices [self .treatment ]]
357
358
else :
You can’t perform that action at this time.
0 commit comments