Skip to content

Commit 2f52446

Browse files
committed
Enable support for more recent versions of sklearn
Signed-off-by: Keith Battocchi <kebatt@microsoft.com>
1 parent 9489ed5 commit 2f52446

File tree

4 files changed

+20
-5
lines changed

4 files changed

+20
-5
lines changed

.github/workflows/publish-documentation.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ jobs:
7979
- name: Install graphviz
8080
run: sudo apt-get -yq install graphviz
8181
- name: Build documentation
82-
run: pip install "sphinx~=7.0" "sphinx_rtd_theme~=2.0.0" && sphinx-build ./doc/ ./build/sphinx/html/ -W
82+
run: pip install "sphinx~=7.0" "sphinx_rtd_theme~=2.0.0" scipy-doctest && sphinx-build ./doc/ ./build/sphinx/html/ -W
8383
- name: Upload docs as artifact
8484
uses: actions/upload-artifact@v4
8585
with:

doc/conf.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,22 @@
1818
import econml
1919
sys.path.insert(0, os.path.abspath('econml'))
2020

21+
# Use scipy-doctest's float-tolerant output checker for compatibility
22+
# with inherited sklearn docstrings that use it for approximate float comparisons.
23+
#
24+
# Why monkey-patch instead of switching to scipy-doctest entirely?
25+
# We run doctests via `sphinx-build -b doctest`, which uses sphinx.ext.doctest.
26+
# This lets us use Sphinx-specific directives like `.. testcode::` and
27+
# `.. testoutput::` blocks throughout our documentation. scipy-doctest only
28+
# supports standard Python doctest syntax and cannot parse these directives.
29+
# By monkey-patching the OutputChecker, we get scipy-doctest's float-tolerant
30+
# comparisons while retaining full Sphinx doctest functionality.
31+
try:
32+
from scipy_doctest import DTChecker
33+
doctest.OutputChecker = DTChecker
34+
except ImportError:
35+
pass # Fall back to default if scipy-doctest not installed
36+
2137

2238
# -- Project information -----------------------------------------------------
2339

@@ -257,4 +273,3 @@ def exclude_entity(app, what, name, obj, skip, opts):
257273

258274
def setup(app):
259275
app.connect('autodoc-skip-member', exclude_entity)
260-
()

econml/sklearn_extensions/linear_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,13 +1247,13 @@ def score(self, X, y, sample_weight=None):
12471247
return self.model.score(X, y, sample_weight)
12481248

12491249
def __getattr__(self, key):
1250-
if key in self._known_params:
1250+
if key in self._known_params or key == '__sklearn_tags__':
12511251
return getattr(self.model, key)
12521252
else:
12531253
raise AttributeError("No attribute " + key)
12541254

12551255
def __setattr__(self, key, value):
1256-
if key in self._known_params:
1256+
if key in self._known_params or key == '__sklearn_tags__':
12571257
setattr(self.model, key, value)
12581258
else:
12591259
super().__setattr__(key, value)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ dependencies = [
2626
# in addition to dependencies)
2727
"numba > 0.53.1",
2828
"scipy > 1.4.0",
29-
"scikit-learn >= 1.0, < 1.7",
29+
"scikit-learn >= 1.0, < 1.9",
3030
"sparse",
3131
"joblib >= 0.13.0",
3232
"statsmodels >= 0.10",

0 commit comments

Comments
 (0)