Skip to content

Commit 282a822

Browse files
committed
chore(skore): Moving confusion matrix display kwargs into set_style (#2293)
Partially address #1957 This PR add a specialized `set_style` for the `ConfusionMatrixDisplay` such that we move the kwargs specifically in this function instead of putting it inside the `plot` one to simplify the user API. The option `policy` is also set to `"update"` by default because it has a better behaviour by default by not overriding all the default parameter already set.
1 parent 5a67cee commit 282a822

File tree

5 files changed

+58
-54
lines changed

5 files changed

+58
-54
lines changed

examples/model_evaluation/plot_estimator_report.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ def operational_decision_cost(y_true, y_pred, amount):
418418
# %%
419419
# More plotting options are available via ``heatmap_kwargs``, which are passed to
420420
# seaborn's heatmap. For example, we can customize the colormap and number format:
421-
cm_display.plot(heatmap_kwargs={"cmap": "Greens", "fmt": ".2e"})
421+
cm_display.set_style(heatmap_kwargs={"cmap": "Greens", "fmt": ".2e"}).plot()
422422
plt.show()
423423

424424
# %%

skore/src/skore/_sklearn/_plot/base.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ class Display(Protocol):
2929
def plot(self, **kwargs: Any) -> None:
3030
"""Display a figure containing the information of the display."""
3131

32-
def set_style(self, **kwargs: Any) -> None:
32+
def set_style(
33+
self, *, policy: Literal["override", "update"] = "update", **kwargs: Any
34+
) -> None:
3335
"""Set the style of the display."""
3436

3537
def frame(self, **kwargs: Any) -> pd.DataFrame:
@@ -136,13 +138,13 @@ def _style_params(self) -> list[str]:
136138
]
137139

138140
def set_style(
139-
self, *, policy: Literal["override", "update"] = "override", **kwargs: Any
141+
self, *, policy: Literal["override", "update"] = "update", **kwargs: Any
140142
):
141143
"""Set the style parameters for the display.
142144
143145
Parameters
144146
----------
145-
policy : Literal["override", "update"], default="override"
147+
policy : Literal["override", "update"], default="update"
146148
Policy to use when setting the style parameters.
147149
If "override", existing settings are set to the provided values.
148150
If "update", existing settings are not changed; only settings that were

skore/src/skore/_sklearn/_plot/metrics/confusion_matrix.py

Lines changed: 46 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,6 @@ def plot(
106106
normalize: Literal["true", "pred", "all"] | None = None,
107107
threshold_value: float | None = None,
108108
subplot_by: Literal["split", "estimator", "auto"] | None = "auto",
109-
heatmap_kwargs: dict | None = None,
110-
facet_grid_kwargs: dict | None = None,
111109
):
112110
"""Plot the confusion matrix.
113111
@@ -135,14 +133,6 @@ def plot(
135133
be subplotted. If "auto", the variable will be automatically determined
136134
based on the report type.
137135
138-
heatmap_kwargs : dict, default=None
139-
Additional keyword arguments to be passed to seaborn's
140-
:func:`seaborn.heatmap`.
141-
142-
facet_grid_kwargs : dict, default=None
143-
Additional keyword arguments to be passed to seaborn's
144-
:class:`seaborn.FacetGrid`.
145-
146136
Returns
147137
-------
148138
self : ConfusionMatrixDisplay
@@ -152,8 +142,6 @@ def plot(
152142
normalize=normalize,
153143
threshold_value=threshold_value,
154144
subplot_by=subplot_by,
155-
heatmap_kwargs=heatmap_kwargs,
156-
facet_grid_kwargs=facet_grid_kwargs,
157145
)
158146

159147
def _plot_matplotlib(
@@ -162,8 +150,6 @@ def _plot_matplotlib(
162150
normalize: Literal["true", "pred", "all"] | None = None,
163151
threshold_value: float | None = None,
164152
subplot_by: Literal["split", "estimator", "auto"] | None = "auto",
165-
heatmap_kwargs: dict | None = None,
166-
facet_grid_kwargs: dict | None = None,
167153
) -> None:
168154
"""Matplotlib implementation of the `plot` method.
169155
@@ -184,28 +170,13 @@ def _plot_matplotlib(
184170
The variable to use for subplotting. If None, the confusion matrix will not
185171
be subplotted. If "auto", the variable will be automatically determined
186172
based on the report type.
187-
188-
heatmap_kwargs : dict, default=None
189-
Additional keyword arguments to be passed to seaborn's
190-
:func:`seaborn.heatmap`.
191-
192-
facet_grid_kwargs : dict, default=None
193-
Additional keyword arguments to be passed to seaborn's
194-
:class:`seaborn.FacetGrid`.
195173
"""
196174
subplot_by_validated = self._validate_subplot_by(subplot_by, self.report_type)
197175

198176
if "cross-validation" in self.report_type and subplot_by_validated != "split":
199177
# Aggregate the data across splits and create custom annotations.
200178
default_fmt = ".3f" if normalize else ".1f"
201-
annot_fmt = (
202-
heatmap_kwargs.pop("fmt", default_fmt)
203-
if heatmap_kwargs
204-
else self._default_heatmap_kwargs.pop("fmt", default_fmt)
205-
# if fmt was changed with set_style
206-
if "fmt" in self._default_heatmap_kwargs
207-
else default_fmt
208-
)
179+
annot_fmt = self._default_heatmap_kwargs.get("fmt", default_fmt)
209180
frame = self.frame(normalize=normalize, threshold_value=threshold_value)
210181
aggregated = (
211182
frame.groupby(
@@ -226,13 +197,11 @@ def _plot_matplotlib(
226197
default_fmt = ".2f" if normalize else "d"
227198

228199
heatmap_kwargs_validated = _validate_style_kwargs(
229-
{"fmt": default_fmt, **self._default_heatmap_kwargs},
230-
heatmap_kwargs or {},
200+
{"fmt": default_fmt, **self._default_heatmap_kwargs}, {}
231201
)
232202

233203
facet_grid_kwargs_validated = _validate_style_kwargs(
234-
{"col": subplot_by_validated, **self._default_facet_grid_kwargs},
235-
facet_grid_kwargs or {},
204+
{"col": subplot_by_validated, **self._default_facet_grid_kwargs}, {}
236205
)
237206
grid = sns.FacetGrid(
238207
data=frame,
@@ -605,3 +574,46 @@ def select_threshold_and_format(group):
605574
frames.append(select_threshold_and_format(self.confusion_matrix))
606575

607576
return pd.concat(frames)
577+
578+
# ignore the type signature because we override kwargs by specifying the name of
579+
# the parameters for the user.
580+
def set_style( # type: ignore[override]
581+
self,
582+
*,
583+
policy: Literal["override", "update"] = "update",
584+
heatmap_kwargs: dict | None = None,
585+
facet_grid_kwargs: dict | None = None,
586+
):
587+
"""Set the style parameters for the display.
588+
589+
Parameters
590+
----------
591+
policy : Literal["override", "update"], default="update"
592+
Policy to use when setting the style parameters.
593+
If "override", existing settings are set to the provided values.
594+
If "update", existing settings are not changed; only settings that were
595+
previously unset are changed.
596+
597+
heatmap_kwargs : dict, default=None
598+
Additional keyword arguments to be passed to seaborn's
599+
:func:`seaborn.heatmap`.
600+
601+
facet_grid_kwargs : dict, default=None
602+
Additional keyword arguments to be passed to seaborn's
603+
:class:`seaborn.FacetGrid`.
604+
605+
Returns
606+
-------
607+
self : object
608+
Returns the instance itself.
609+
610+
Raises
611+
------
612+
ValueError
613+
If a style parameter is unknown.
614+
"""
615+
return super().set_style(
616+
policy=policy,
617+
heatmap_kwargs=heatmap_kwargs or {},
618+
facet_grid_kwargs=facet_grid_kwargs or {},
619+
)

skore/tests/unit/displays/confusion_matrix/test_common.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def test_facet_grid_kwargs(pyplot, fixture_prefix, task, request):
109109
display.plot()
110110
assert display.figure_.get_figheight() == 6
111111

112-
display.plot(facet_grid_kwargs={"height": 8})
112+
display.set_style(facet_grid_kwargs={"height": 8}).plot()
113113
assert display.figure_.get_figheight() == 8
114114

115115
@pytest.mark.parametrize("task", ["binary", "multiclass"])
@@ -130,15 +130,13 @@ def get_ax(display):
130130
display = report.metrics.confusion_matrix()
131131
display.plot()
132132
assert get_ax(display).collections[0].get_cmap().name == "Blues"
133-
display.plot(heatmap_kwargs={"cmap": "Reds"})
133+
display.set_style(heatmap_kwargs={"cmap": "Reds"}).plot()
134134
assert get_ax(display).collections[0].get_cmap().name == "Reds"
135-
display.set_style(heatmap_kwargs={"cmap": "Greens"}, policy="update").plot()
136-
assert get_ax(display).collections[0].get_cmap().name == "Greens"
137135

138136
display = report.metrics.confusion_matrix()
139137
display.plot()
140138
assert len(get_ax(display).texts) > 1
141-
display.plot(heatmap_kwargs={"annot": False})
139+
display.set_style(heatmap_kwargs={"annot": False}).plot()
142140
# There is still the pos_label annotation
143141
assert len(get_ax(display).texts) == n_base_elements
144142
plt.close("all")
@@ -148,25 +146,17 @@ def get_ax(display):
148146
for text in get_ax(display).texts:
149147
text_content = text.get_text()
150148
assert "." in text_content or "*" in text_content
151-
display.plot(normalize="all", heatmap_kwargs={"fmt": ".2e"})
149+
display.set_style(heatmap_kwargs={"fmt": ".2e"}).plot(normalize="all")
152150
for text in get_ax(display).texts:
153151
text_content = text.get_text()
154152
assert "e" in text_content
155-
display.set_style(heatmap_kwargs={"fmt": ".2E"}, policy="update").plot(
156-
normalize="all"
157-
)
158-
for text in get_ax(display).texts:
159-
text_content = text.get_text()
160-
assert "E" in text_content or "*" in text_content
161153
plt.close("all")
162154

163155
display = report.metrics.confusion_matrix()
164156
display.plot()
165157
assert len(display.figure_.axes) == n_plots
166-
display.plot(heatmap_kwargs={"cbar": True})
158+
display.set_style(heatmap_kwargs={"cbar": True}).plot()
167159
assert len(display.figure_.axes) == 2 * n_plots
168-
display.set_style(heatmap_kwargs={"cbar": False}, policy="update").plot()
169-
assert len(display.figure_.axes) == n_plots
170160
plt.close("all")
171161

172162
@pytest.mark.parametrize("task", ["binary", "multiclass"])

skore/tests/unit/displays/test_style.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def test_style_mixin():
2121
"initial_state, override_value, expected_result, use_explicit_policy",
2222
[
2323
(None, {"a": 1, "b": 2}, {"a": 1, "b": 2}, False),
24-
({"a": 1, "b": 2}, {"c": 3}, {"c": 3}, False),
24+
({"a": 1, "b": 2}, {"c": 3}, {"a": 1, "b": 2, "c": 3}, False),
2525
({"c": 3}, {"a": 1, "b": 2}, {"a": 1, "b": 2}, True),
2626
({"a": 1, "b": 2}, 42, 42, True),
2727
],

0 commit comments

Comments
 (0)