Skip to content

Commit d670878

Browse files
committed
MAINT: more refactoring by moving more files into _repr_html
1 parent 306d50e commit d670878

File tree

14 files changed

+173
-203
lines changed

14 files changed

+173
-203
lines changed

sklearn/base.py

Lines changed: 4 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616
from . import __version__
1717
from ._config import config_context, get_config
1818
from .exceptions import InconsistentVersionWarning
19-
from .utils._estimator_html_repr import _HTMLDocumentationLinkMixin, estimator_html_repr
2019
from .utils._metadata_requests import _MetadataRequester, _routing_enabled
2120
from .utils._missing import is_scalar_nan
2221
from .utils._param_validation import validate_parameter_constraints
23-
from .utils._repr_html.base import ReprHTMLMixin
22+
from .utils._repr_html.base import ReprHTMLMixin, _HTMLDocumentationLinkMixin
23+
from .utils._repr_html.estimator import estimator_html_repr
2424
from .utils._repr_html.params import ParamsDict
2525
from .utils._set_output import _SetOutputMixin
2626
from .utils._tags import (
@@ -197,6 +197,8 @@ class BaseEstimator(ReprHTMLMixin, _HTMLDocumentationLinkMixin, _MetadataRequest
197197
array([3, 3, 3])
198198
"""
199199

200+
_html_repr = estimator_html_repr
201+
200202
@classmethod
201203
def _get_param_names(cls):
202204
"""Get parameter names for the estimator"""
@@ -473,36 +475,6 @@ class attribute, which is a dictionary `param_name: list of constraints`. See
473475
caller_name=self.__class__.__name__,
474476
)
475477

476-
@property
477-
def _repr_html_(self):
478-
"""HTML representation of estimator.
479-
480-
This is redundant with the logic of `_repr_mimebundle_`. The latter
481-
should be favored in the long term, `_repr_html_` is only
482-
implemented for consumers who do not interpret `_repr_mimbundle_`.
483-
"""
484-
if get_config()["display"] != "diagram":
485-
raise AttributeError(
486-
"_repr_html_ is only defined when the "
487-
"'display' configuration option is set to "
488-
"'diagram'"
489-
)
490-
return self._repr_html_inner
491-
492-
def _repr_html_inner(self):
493-
"""This function is returned by the @property `_repr_html_` to make
494-
`hasattr(estimator, "_repr_html_") return `True` or `False` depending
495-
on `get_config()["display"]`.
496-
"""
497-
return estimator_html_repr(self)
498-
499-
def _repr_mimebundle_(self, **kwargs):
500-
"""Mime bundle used by jupyter kernels to display estimator"""
501-
output = {"text/plain": repr(self)}
502-
if get_config()["display"] == "diagram":
503-
output["text/html"] = estimator_html_repr(self)
504-
return output
505-
506478

507479
class ClassifierMixin:
508480
"""Mixin class for all classifiers in scikit-learn.

sklearn/compose/_column_transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@
2020
from ..pipeline import _fit_transform_one, _name_estimators, _transform_one
2121
from ..preprocessing import FunctionTransformer
2222
from ..utils import Bunch
23-
from ..utils._estimator_html_repr import _VisualBlock
2423
from ..utils._indexing import _determine_key_type, _get_column_indices, _safe_indexing
2524
from ..utils._metadata_requests import METHODS
2625
from ..utils._param_validation import HasMethods, Hidden, Interval, StrOptions
26+
from ..utils._repr_html.estimator import _VisualBlock
2727
from ..utils._set_output import (
2828
_get_container_adapter,
2929
_get_output_config,

sklearn/ensemble/_stacking.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
from ..model_selection import check_cv, cross_val_predict
2525
from ..preprocessing import LabelEncoder
2626
from ..utils import Bunch
27-
from ..utils._estimator_html_repr import _VisualBlock
2827
from ..utils._param_validation import HasMethods, StrOptions
28+
from ..utils._repr_html.estimator import _VisualBlock
2929
from ..utils.metadata_routing import (
3030
MetadataRouter,
3131
MethodMapping,

sklearn/model_selection/_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@
3131
get_scorer_names,
3232
)
3333
from ..utils import Bunch, check_random_state
34-
from ..utils._estimator_html_repr import _VisualBlock
3534
from ..utils._param_validation import HasMethods, Interval, StrOptions
35+
from ..utils._repr_html.estimator import _VisualBlock
3636
from ..utils._tags import get_tags
3737
from ..utils.metadata_routing import (
3838
MetadataRouter,

sklearn/pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
from .exceptions import NotFittedError
1717
from .preprocessing import FunctionTransformer
1818
from .utils import Bunch
19-
from .utils._estimator_html_repr import _VisualBlock
2019
from .utils._metadata_requests import METHODS
2120
from .utils._param_validation import HasMethods, Hidden
21+
from .utils._repr_html.estimator import _VisualBlock
2222
from .utils._set_output import (
2323
_get_container_adapter,
2424
_safe_set_output,

sklearn/preprocessing/_function_transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
import numpy as np
88

99
from ..base import BaseEstimator, TransformerMixin, _fit_context
10-
from ..utils._estimator_html_repr import _VisualBlock
1110
from ..utils._param_validation import StrOptions
11+
from ..utils._repr_html.estimator import _VisualBlock
1212
from ..utils._set_output import (
1313
_get_adapter_from_container,
1414
_get_output_config,

sklearn/utils/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from . import metadata_routing
88
from ._bunch import Bunch
99
from ._chunking import gen_batches, gen_even_slices
10-
from ._estimator_html_repr import estimator_html_repr
1110

1211
# Make _safe_indexing importable from here for backward compat as this particular
1312
# helper is considered semi-private and typically very useful for third-party
@@ -20,6 +19,8 @@
2019
shuffle,
2120
)
2221
from ._mask import safe_mask
22+
from ._repr_html.base import _HTMLDocumentationLinkMixin # noqa: F401
23+
from ._repr_html.estimator import estimator_html_repr
2324
from ._tags import (
2425
ClassifierTags,
2526
InputTags,

sklearn/utils/_repr_html/base.py

Lines changed: 114 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,123 @@
11
# Authors: The scikit-learn developers
22
# SPDX-License-Identifier: BSD-3-Clause
33

4-
from sklearn._config import get_config
4+
import itertools
5+
6+
from ... import __version__
7+
from ..._config import get_config
8+
from ..fixes import parse_version
9+
10+
11+
class _HTMLDocumentationLinkMixin:
12+
"""Mixin class allowing to generate a link to the API documentation.
13+
14+
This mixin relies on three attributes:
15+
- `_doc_link_module`: it corresponds to the root module (e.g. `sklearn`). Using this
16+
mixin, the default value is `sklearn`.
17+
- `_doc_link_template`: it corresponds to the template used to generate the
18+
link to the API documentation. Using this mixin, the default value is
19+
`"https://scikit-learn.org/{version_url}/modules/generated/
20+
{estimator_module}.{estimator_name}.html"`.
21+
- `_doc_link_url_param_generator`: it corresponds to a function that generates the
22+
parameters to be used in the template when the estimator module and name are not
23+
sufficient.
24+
25+
The method :meth:`_get_doc_link` generates the link to the API documentation for a
26+
given estimator.
27+
28+
This useful provides all the necessary states for
29+
:func:`sklearn.utils.estimator_html_repr` to generate a link to the API
30+
documentation for the estimator HTML diagram.
31+
32+
Examples
33+
--------
34+
If the default values for `_doc_link_module`, `_doc_link_template` are not suitable,
35+
then you can override them and provide a method to generate the URL parameters:
36+
>>> from sklearn.base import BaseEstimator
37+
>>> doc_link_template = "https://address.local/{single_param}.html"
38+
>>> def url_param_generator(estimator):
39+
... return {"single_param": estimator.__class__.__name__}
40+
>>> class MyEstimator(BaseEstimator):
41+
... # use "builtins" since it is the associated module when declaring
42+
... # the class in a docstring
43+
... _doc_link_module = "builtins"
44+
... _doc_link_template = doc_link_template
45+
... _doc_link_url_param_generator = url_param_generator
46+
>>> estimator = MyEstimator()
47+
>>> estimator._get_doc_link()
48+
'https://address.local/MyEstimator.html'
49+
50+
If instead of overriding the attributes inside the class definition, you want to
51+
override a class instance, you can use `types.MethodType` to bind the method to the
52+
instance:
53+
>>> import types
54+
>>> estimator = BaseEstimator()
55+
>>> estimator._doc_link_template = doc_link_template
56+
>>> estimator._doc_link_url_param_generator = types.MethodType(
57+
... url_param_generator, estimator)
58+
>>> estimator._get_doc_link()
59+
'https://address.local/BaseEstimator.html'
60+
"""
61+
62+
_doc_link_module = "sklearn"
63+
_doc_link_url_param_generator = None
64+
65+
@property
66+
def _doc_link_template(self):
67+
sklearn_version = parse_version(__version__)
68+
if sklearn_version.dev is None:
69+
version_url = f"{sklearn_version.major}.{sklearn_version.minor}"
70+
else:
71+
version_url = "dev"
72+
return getattr(
73+
self,
74+
"__doc_link_template",
75+
(
76+
f"https://scikit-learn.org/{version_url}/modules/generated/"
77+
"{estimator_module}.{estimator_name}.html"
78+
),
79+
)
80+
81+
@_doc_link_template.setter
82+
def _doc_link_template(self, value):
83+
setattr(self, "__doc_link_template", value)
84+
85+
def _get_doc_link(self):
86+
"""Generates a link to the API documentation for a given estimator.
87+
88+
This method generates the link to the estimator's documentation page
89+
by using the template defined by the attribute `_doc_link_template`.
90+
91+
Returns
92+
-------
93+
url : str
94+
The URL to the API documentation for this estimator. If the estimator does
95+
not belong to module `_doc_link_module`, the empty string (i.e. `""`) is
96+
returned.
97+
"""
98+
if self.__class__.__module__.split(".")[0] != self._doc_link_module:
99+
return ""
100+
101+
if self._doc_link_url_param_generator is None:
102+
estimator_name = self.__class__.__name__
103+
# Construct the estimator's module name, up to the first private submodule.
104+
# This works because in scikit-learn all public estimators are exposed at
105+
# that level, even if they actually live in a private sub-module.
106+
estimator_module = ".".join(
107+
itertools.takewhile(
108+
lambda part: not part.startswith("_"),
109+
self.__class__.__module__.split("."),
110+
)
111+
)
112+
return self._doc_link_template.format(
113+
estimator_module=estimator_module, estimator_name=estimator_name
114+
)
115+
return self._doc_link_template.format(**self._doc_link_url_param_generator())
5116

6117

7118
class ReprHTMLMixin:
8119
@property
9120
def _repr_html_(self):
10-
# Taken from sklearn.base.BaseEstimator
11121
"""HTML representation of estimator.
12122
This is redundant with the logic of `_repr_mimebundle_`. The latter
13123
should be favored in the long term, `_repr_html_` is only
@@ -22,15 +132,11 @@ def _repr_html_(self):
22132
return self._repr_html_inner
23133

24134
def _repr_html_inner(self):
25-
from sklearn.utils._repr_html.params import _html_template
26-
27-
return _html_template(self)
135+
return self._html_repr()
28136

29137
def _repr_mimebundle_(self, **kwargs):
30138
"""Mime bundle used by jupyter kernels to display estimator"""
31-
from sklearn.utils._repr_html.params import _html_template
32-
33139
output = {"text/plain": repr(self)}
34140
if get_config()["display"] == "diagram":
35-
output["text/html"] = _html_template(self)
141+
output["text/html"] = self._html_repr()
36142
return output
File renamed without changes.

0 commit comments

Comments
 (0)