Skip to content

Commit 999c07c

Browse files
00helloworldkbattocchi
authored andcommitted
Update criterion in _interpreters.py
1 parent d20bbd2 commit 999c07c

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

econml/cate_interpreter/_interpreters.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import abc
55
import numbers
66
import numpy as np
7+
from packaging import version
8+
import sklearn
79
from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier
810
from sklearn.utils import check_array
911
from ..policy import PolicyTree
@@ -149,7 +151,7 @@ def __init__(self, *,
149151
self.include_uncertainty = include_model_uncertainty
150152
self.uncertainty_level = uncertainty_level
151153
self.uncertainty_only_on_leaves = uncertainty_only_on_leaves
152-
self.criterion = "mse"
154+
self.criterion = "squared_error" if version.parse(sklearn.__version__) >= version.parse("1.0") else "mse"
153155
self.splitter = splitter
154156
self.max_depth = max_depth
155157
self.min_samples_split = min_samples_split

0 commit comments

Comments
 (0)