Skip to content

Commit 81aa88b

Browse files
committed
done
1 parent 87fd814 commit 81aa88b

File tree

6 files changed

+159
-104
lines changed

6 files changed

+159
-104
lines changed

skore/src/skore/_sklearn/_estimator/metrics_accessor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -407,10 +407,10 @@ class is set to the one provided when creating the report. If `None`,
407407
names=["Metric", "Label / Average"],
408408
)
409409
else:
410-
index = pd.Index([metric_name], name="Metric")
410+
index = pd.Index([metric_name], name="Metric", dtype=object)
411411
score_array = np.array(score).reshape(-1, 1)
412412
else:
413-
index = pd.Index([metric_name], name="Metric")
413+
index = pd.Index([metric_name], name="Metric", dtype=object)
414414
score_array = np.array(score).reshape(-1, 1)
415415
elif self._parent._ml_task == "multiclass-classification":
416416
if isinstance(score, dict):
@@ -430,7 +430,7 @@ class is set to the one provided when creating the report. If `None`,
430430
)
431431
score_array = np.array(score).reshape(-1, 1)
432432
else:
433-
index = pd.Index([metric_name], name="Metric")
433+
index = pd.Index([metric_name], name="Metric", dtype=object)
434434
score_array = np.array(score).reshape(-1, 1)
435435
elif self._parent._ml_task in ("regression", "multioutput-regression"):
436436
if isinstance(score, list):
@@ -440,7 +440,7 @@ class is set to the one provided when creating the report. If `None`,
440440
)
441441
score_array = np.array(score).reshape(-1, 1)
442442
else:
443-
index = pd.Index([metric_name], name="Metric")
443+
index = pd.Index([metric_name], name="Metric", dtype=object)
444444
score_array = np.array(score).reshape(-1, 1)
445445
else: # unknown task - try our best
446446
index = None if isinstance(score, Iterable) else [metric_name]

skore/src/skore/_sklearn/_plot/data/table_report.py

Lines changed: 42 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,12 @@ def _truncate_top_k_categories(
6565
{v: ellide_string(v, max_len=20) for v in values if isinstance(v, str)}
6666
)
6767
else:
68+
original_dtype = col.dtype
6869
col[~keep] = other_label
6970
col = col.apply(
7071
lambda x: ellide_string(x, max_len=20) if isinstance(x, str) else x
7172
)
72-
col = col.astype(object)
73+
col = col.astype(original_dtype)
7374
return col
7475

7576

@@ -290,7 +291,7 @@ def _plot_matplotlib(
290291
hue: str | None = None,
291292
kind: Literal["dist", "corr"] = "dist",
292293
top_k_categories: int = 20,
293-
) -> None:
294+
) -> tuple[Figure, Axes]:
294295
"""Matplotlib implementation of the `plot` method."""
295296
self.figure_, self.ax_ = plt.subplots()
296297
if kind == "dist":
@@ -306,6 +307,8 @@ def _plot_matplotlib(
306307
y=y,
307308
k=top_k_categories,
308309
histplot_kwargs=self._default_histplot_kwargs,
310+
figure=self.figure_,
311+
ax=self.ax_,
309312
)
310313
case _:
311314
self._plot_distribution_2d(
@@ -317,6 +320,8 @@ def _plot_matplotlib(
317320
stripplot_kwargs=self._default_stripplot_kwargs,
318321
boxplot_kwargs=self._default_boxplot_kwargs,
319322
heatmap_kwargs=self._default_heatmap_kwargs,
323+
figure=self.figure_,
324+
ax=self.ax_,
320325
)
321326

322327
elif kind == "corr":
@@ -327,18 +332,22 @@ def _plot_matplotlib(
327332
raise ValueError(
328333
f"When {kind=!r}, {param_name!r} argument must be None."
329334
)
330-
self._plot_cramer(heatmap_kwargs=self._default_heatmap_kwargs)
335+
self._plot_cramer(heatmap_kwargs=self._default_heatmap_kwargs, ax=self.ax_)
331336

332337
else:
333338
raise ValueError(f"'kind' options are 'dist', 'corr', got {kind!r}.")
334339

340+
return (self.figure_, self.ax_)
341+
335342
def _plot_distribution_1d(
336343
self,
337344
*,
338345
x: str | None,
339346
y: str | None,
340347
k: int,
341348
histplot_kwargs: dict[str, Any],
349+
figure: Figure,
350+
ax: Axes,
342351
) -> None:
343352
"""Plot 1-dimensional distribution of a feature.
344353
@@ -387,16 +396,16 @@ def _plot_distribution_1d(
387396
histplot_params = {"x": column}
388397
despine_params = {"bottom": is_categorical}
389398
if duration_unit is not None:
390-
self.ax_.set(xlabel=f"{duration_unit.capitalize()}s")
399+
ax.set(xlabel=f"{duration_unit.capitalize()}s")
391400
else: # y is not None
392401
histplot_params = {"y": column}
393402
despine_params = {"left": is_categorical}
394403
if duration_unit is not None:
395-
self.ax_.set(ylabel=f"{duration_unit.capitalize()}s")
404+
ax.set(ylabel=f"{duration_unit.capitalize()}s")
396405

397-
sns.histplot(ax=self.ax_, **histplot_params, **histplot_kwargs_validated)
406+
sns.histplot(ax=ax, **histplot_params, **histplot_kwargs_validated)
398407
sns.despine(
399-
self.figure_,
408+
figure,
400409
top=True,
401410
right=True,
402411
trim=True,
@@ -406,17 +415,17 @@ def _plot_distribution_1d(
406415

407416
if is_categorical:
408417
_resize_categorical_axis(
409-
figure=self.figure_,
410-
ax=self.ax_,
418+
figure=figure,
419+
ax=ax,
411420
n_categories=sbd.n_unique(column),
412421
is_x_axis=x is not None,
413422
)
414423

415424
if x is not None and any(
416-
len(label.get_text()) > 1 for label in self.ax_.get_xticklabels()
425+
len(label.get_text()) > 1 for label in ax.get_xticklabels()
417426
):
418427
# rotate only for string longer than 1 character
419-
_rotate_ticklabels(self.ax_, rotation=45)
428+
_rotate_ticklabels(ax, rotation=45)
420429

421430
def _plot_distribution_2d(
422431
self,
@@ -429,6 +438,8 @@ def _plot_distribution_2d(
429438
scatterplot_kwargs: dict[str, Any],
430439
hue: str | None = None,
431440
k: int = 20,
441+
figure: Figure,
442+
ax: Axes,
432443
) -> None:
433444
"""Plot 2-dimensional distribution of two features.
434445
@@ -478,7 +489,7 @@ def _plot_distribution_2d(
478489
x=x,
479490
y=y,
480491
hue=hue,
481-
ax=self.ax_,
492+
ax=ax,
482493
**scatterplot_kwargs_validated,
483494
)
484495
elif is_x_num or is_y_num:
@@ -512,23 +523,21 @@ def _plot_distribution_2d(
512523
else:
513524
x = _truncate_top_k_categories(x, k)
514525

515-
sns.boxplot(x=x, y=y, ax=self.ax_, **boxplot_kwargs_validated)
516-
sns.stripplot(x=x, y=y, hue=hue, ax=self.ax_, **stripplot_kwargs_validated)
526+
sns.boxplot(x=x, y=y, ax=ax, **boxplot_kwargs_validated)
527+
sns.stripplot(x=x, y=y, hue=hue, ax=ax, **stripplot_kwargs_validated)
517528

518529
_resize_categorical_axis(
519-
figure=self.figure_,
520-
ax=self.ax_,
530+
figure=figure,
531+
ax=ax,
521532
n_categories=sbd.n_unique(y) if is_x_num else sbd.n_unique(x),
522533
is_x_axis=not is_x_num,
523534
)
524535
if is_x_num:
525536
despine_params["left"] = True
526537
else:
527538
despine_params["bottom"] = True
528-
if any(
529-
len(label.get_text()) > 1 for label in self.ax_.get_xticklabels()
530-
):
531-
_rotate_ticklabels(self.ax_, rotation=45)
539+
if any(len(label.get_text()) > 1 for label in ax.get_xticklabels()):
540+
_rotate_ticklabels(ax, rotation=45)
532541
else:
533542
if (hue is not None) and (not sbd.is_numeric(hue)):
534543
raise ValueError(
@@ -576,9 +585,9 @@ def _plot_distribution_2d(
576585
},
577586
heatmap_kwargs,
578587
)
579-
sns.heatmap(contingency_table, ax=self.ax_, **heatmap_kwargs_validated)
588+
sns.heatmap(contingency_table, ax=ax, **heatmap_kwargs_validated)
580589
despine_params.update(left=True, bottom=True)
581-
self.ax_.tick_params(axis="both", length=0)
590+
ax.tick_params(axis="both", length=0)
582591

583592
for is_x_axis, x_or_y in zip(
584593
[True, False],
@@ -589,26 +598,28 @@ def _plot_distribution_2d(
589598
strict=False,
590599
):
591600
_resize_categorical_axis(
592-
figure=self.figure_,
593-
ax=self.ax_,
601+
figure=figure,
602+
ax=ax,
594603
n_categories=sbd.n_unique(x_or_y),
595604
is_x_axis=is_x_axis,
596605
size_per_category=size_per_category,
597606
)
598607

599-
sns.despine(self.figure_, **despine_params)
608+
sns.despine(figure, **despine_params)
600609

601-
self.ax_.set(xlabel=sbd.name(x), ylabel=sbd.name(y))
602-
if self.ax_.legend_ is not None:
603-
sns.move_legend(self.ax_, (1.05, 0.0))
610+
ax.set(xlabel=sbd.name(x), ylabel=sbd.name(y))
611+
if ax.legend_ is not None:
612+
sns.move_legend(ax, (1.05, 0.0))
604613

605-
def _plot_cramer(self, *, heatmap_kwargs: dict[str, Any]) -> None:
614+
def _plot_cramer(self, *, heatmap_kwargs: dict[str, Any], ax: Axes) -> None:
606615
"""Plot Cramer's V correlation among all columns.
607616
608617
Parameters
609618
----------
610619
heatmap_kwargs : dict
611620
Keyword arguments to be passed to heatmap.
621+
ax : Axes
622+
The axes to plot on.
612623
"""
613624
heatmap_kwargs_validated = _validate_style_kwargs(
614625
{
@@ -632,8 +643,8 @@ def _plot_cramer(self, *, heatmap_kwargs: dict[str, Any]) -> None:
632643
# and keep the diagonal as well.
633644
mask = np.triu(np.ones_like(cramer_v_table, dtype=bool), k=1)
634645

635-
sns.heatmap(cramer_v_table, mask=mask, ax=self.ax_, **heatmap_kwargs_validated)
636-
self.ax_.set(title="Cramer's V Correlation")
646+
sns.heatmap(cramer_v_table, mask=mask, ax=ax, **heatmap_kwargs_validated)
647+
ax.set(title="Cramer's V Correlation")
637648

638649
def frame(
639650
self, *, kind: Literal["dataset", "top-associations"] = "dataset"

skore/src/skore/_sklearn/_plot/metrics/confusion_matrix.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import numpy as np
55
import pandas as pd
6-
import seaborn
76
import seaborn as sns
87
from numpy.typing import NDArray
98
from sklearn.metrics import confusion_matrix as sklearn_confusion_matrix
@@ -54,6 +53,21 @@ class ConfusionMatrixDisplay(_ClassifierDisplayMixin, DisplayMixin):
5453
The estimator's method that was used to get the predictions. The possible
5554
values are: "predict", "predict_proba", and "decision_function".
5655
56+
Attributes
57+
----------
58+
thresholds : ndarray of shape (n_thresholds,)
59+
Thresholds of the decision function. Each threshold is associated with a
60+
confusion matrix. Only available for binary classification. Thresholds are
61+
sorted in ascending order.
62+
63+
facet_ : seaborn FacetGrid
64+
FacetGrid containing the confusion matrix.
65+
66+
figure_ : matplotlib Figure
67+
Figure containing the confusion matrix.
68+
69+
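ax_ : matplotlib Axes or ndarray of matplotlib Axes
70+
Axes with the confusion matrix. A single Axes when there is only one subplot, otherwise a flattened array of Axes (one per facet).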
5771
"""
5872

5973
_default_heatmap_kwargs: dict = {
@@ -95,7 +109,7 @@ def plot(
95109
normalize: Literal["true", "pred", "all"] | None = None,
96110
threshold_value: float | None = None,
97111
subplot_by: Literal["split", "estimator", "auto"] | None = "auto",
98-
) -> seaborn.FacetGrid:
112+
):
99113
"""Plot the confusion matrix.
100114
101115
In binary classification, the confusion matrix can be displayed at various
@@ -124,8 +138,8 @@ def plot(
124138
125139
Returns
126140
-------
127-
facet : seaborn.FacetGrid
128-
Confusion matrix visualization.
141+
self : ConfusionMatrixDisplay
142+
Configured with the confusion matrix.
129143
"""
130144
return self._plot(
131145
normalize=normalize,
@@ -139,7 +153,7 @@ def _plot_matplotlib(
139153
normalize: Literal["true", "pred", "all"] | None = None,
140154
threshold_value: float | None = None,
141155
subplot_by: Literal["split", "estimator", "auto"] | None = "auto",
142-
) -> seaborn.FacetGrid:
156+
) -> sns.axisgrid.FacetGrid:
143157
"""Matplotlib implementation of the `plot` method.
144158
145159
Parameters
@@ -192,11 +206,11 @@ def _plot_matplotlib(
192206
facet_grid_kwargs_validated = _validate_style_kwargs(
193207
{"col": subplot_by_validated, **self._default_facet_grid_kwargs}, {}
194208
)
195-
facet_ = sns.FacetGrid(
209+
self.facet_ = sns.FacetGrid(
196210
data=frame,
197211
**facet_grid_kwargs_validated,
198212
)
199-
figure_, ax_ = facet_.figure, facet_.axes.flatten()
213+
self.figure_, self.ax_ = self.facet_.figure, self.facet_.axes.flatten()
200214

201215
def plot_heatmap(data, **kwargs):
202216
"""Plot heatmap for each facet."""
@@ -214,7 +228,7 @@ def plot_heatmap(data, **kwargs):
214228

215229
sns.heatmap(heatmap_data, **kwargs)
216230

217-
facet_.map_dataframe(plot_heatmap, **heatmap_kwargs_validated)
231+
self.facet_.map_dataframe(plot_heatmap, **heatmap_kwargs_validated)
218232

219233
info_data_source = (
220234
f"Data source: {self.data_source.capitalize()} set"
@@ -229,10 +243,10 @@ def plot_heatmap(data, **kwargs):
229243
if threshold_value is None:
230244
threshold_value = 0.5 if self.response_method == "predict_proba" else 0
231245
title = f"{title}\nDecision threshold: {threshold_value:.2f}"
232-
figure_.suptitle(f"{title}\n{info_data_source}")
246+
self.facet_.figure.suptitle(f"{title}\n{info_data_source}")
233247

234-
for axis in ax_:
235-
axis.set(
248+
for ax in self.facet_.axes.flatten():
249+
ax.set(
236250
xlabel="Predicted label",
237251
ylabel="True label",
238252
)
@@ -242,20 +256,20 @@ def plot_heatmap(data, **kwargs):
242256
for label in self.display_labels
243257
]
244258

245-
axis.set(
259+
ax.set(
246260
xticklabels=ticklabels,
247261
yticklabels=ticklabels,
248262
)
249263

250-
axis.text(
264+
ax.text(
251265
-0.15,
252266
-0.15,
253267
"*: the positive class",
254268
fontsize=9,
255269
style="italic",
256270
verticalalignment="bottom",
257271
horizontalalignment="left",
258-
transform=axis.transAxes,
272+
transform=ax.transAxes,
259273
bbox={
260274
"boxstyle": "round",
261275
"facecolor": "white",
@@ -264,9 +278,10 @@ def plot_heatmap(data, **kwargs):
264278
},
265279
)
266280

267-
if len(ax_) == 1:
268-
ax_ = ax_[0]
269-
return facet_
281+
if len(self.ax_) == 1:
282+
self.ax_ = self.ax_[0]
283+
284+
return self.facet_
270285

271286
def _validate_subplot_by(
272287
self,

0 commit comments

Comments
 (0)