Skip to content

Commit 8ded7f4

Browse files
authored
ENH add support for Array API to mean_pinball_loss and explained_variance_score (scikit-learn#29978)
1 parent fba028b commit 8ded7f4

File tree

5 files changed

+52
-25
lines changed

5 files changed

+52
-25
lines changed

doc/modules/array_api.rst

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -116,11 +116,13 @@ Metrics
116116
- :func:`sklearn.metrics.cluster.entropy`
117117
- :func:`sklearn.metrics.accuracy_score`
118118
- :func:`sklearn.metrics.d2_tweedie_score`
119+
- :func:`sklearn.metrics.explained_variance_score`
119120
- :func:`sklearn.metrics.f1_score`
120121
- :func:`sklearn.metrics.max_error`
121122
- :func:`sklearn.metrics.mean_absolute_error`
122123
- :func:`sklearn.metrics.mean_absolute_percentage_error`
123124
- :func:`sklearn.metrics.mean_gamma_deviance`
125+
- :func:`sklearn.metrics.mean_pinball_loss`
124126
- :func:`sklearn.metrics.mean_poisson_deviance` (requires `enabling array API support for SciPy <https://docs.scipy.org/doc/scipy/dev/api-dev/array_api.html#using-array-api-standard-support>`_)
125127
- :func:`sklearn.metrics.mean_squared_error`
126128
- :func:`sklearn.metrics.mean_squared_log_error`
Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,3 @@
1+
- :func:`sklearn.metrics.explained_variance_score` and
2+
:func:`sklearn.metrics.mean_pinball_loss` now support Array API compatible inputs.
3+
by :user:`Virgil Chan <virchan>`

sklearn/metrics/_regression.py

Lines changed: 38 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -288,7 +288,7 @@ def mean_absolute_error(
288288
if multioutput == "raw_values":
289289
return output_errors
290290
elif multioutput == "uniform_average":
291-
# pass None as weights to np.average: uniform mean
291+
# pass None as weights to _average: uniform mean
292292
multioutput = None
293293

294294
# Average across the outputs (if needed).
@@ -360,35 +360,45 @@ def mean_pinball_loss(
360360
>>> from sklearn.metrics import mean_pinball_loss
361361
>>> y_true = [1, 2, 3]
362362
>>> mean_pinball_loss(y_true, [0, 2, 3], alpha=0.1)
363-
np.float64(0.03...)
363+
0.03...
364364
>>> mean_pinball_loss(y_true, [1, 2, 4], alpha=0.1)
365-
np.float64(0.3...)
365+
0.3...
366366
>>> mean_pinball_loss(y_true, [0, 2, 3], alpha=0.9)
367-
np.float64(0.3...)
367+
0.3...
368368
>>> mean_pinball_loss(y_true, [1, 2, 4], alpha=0.9)
369-
np.float64(0.03...)
369+
0.03...
370370
>>> mean_pinball_loss(y_true, y_true, alpha=0.1)
371-
np.float64(0.0)
371+
0.0
372372
>>> mean_pinball_loss(y_true, y_true, alpha=0.9)
373-
np.float64(0.0)
373+
0.0
374374
"""
375-
y_type, y_true, y_pred, multioutput = _check_reg_targets(
376-
y_true, y_pred, multioutput
375+
xp, _ = get_namespace(y_true, y_pred, sample_weight, multioutput)
376+
377+
_, y_true, y_pred, sample_weight, multioutput = (
378+
_check_reg_targets_with_floating_dtype(
379+
y_true, y_pred, sample_weight, multioutput, xp=xp
380+
)
377381
)
382+
378383
check_consistent_length(y_true, y_pred, sample_weight)
379384
diff = y_true - y_pred
380-
sign = (diff >= 0).astype(diff.dtype)
385+
sign = xp.astype(diff >= 0, diff.dtype)
381386
loss = alpha * sign * diff - (1 - alpha) * (1 - sign) * diff
382-
output_errors = np.average(loss, weights=sample_weight, axis=0)
387+
output_errors = _average(loss, weights=sample_weight, axis=0)
383388

384389
if isinstance(multioutput, str) and multioutput == "raw_values":
385390
return output_errors
386391

387392
if isinstance(multioutput, str) and multioutput == "uniform_average":
388-
# pass None as weights to np.average: uniform mean
393+
# pass None as weights to _average: uniform mean
389394
multioutput = None
390395

391-
return np.average(output_errors, weights=multioutput)
396+
# Average across the outputs (if needed).
397+
# The second call to `_average` should always return
398+
# a scalar array that we convert to a Python float to
399+
# consistently return the same eager evaluated value.
400+
# Therefore, `axis=None`.
401+
return float(_average(output_errors, weights=multioutput))
392402

393403

394404
@validate_params(
@@ -949,12 +959,12 @@ def _assemble_r2_explained_variance(
949959
# return scores individually
950960
return output_scores
951961
elif multioutput == "uniform_average":
952-
# Passing None as weights to np.average results is uniform mean
962+
# pass None as weights to _average: uniform mean
953963
avg_weights = None
954964
elif multioutput == "variance_weighted":
955965
avg_weights = denominator
956966
if not xp.any(nonzero_denominator):
957-
# All weights are zero, np.average would raise a ZeroDiv error.
967+
# All weights are zero, _average would raise a ZeroDiv error.
958968
# This only happens when all y are constant (or 1-element long)
959969
# Since weights are all equal, fall back to uniform weights.
960970
avg_weights = None
@@ -1083,28 +1093,32 @@ def explained_variance_score(
10831093
>>> explained_variance_score(y_true, y_pred, force_finite=False)
10841094
-inf
10851095
"""
1086-
y_type, y_true, y_pred, multioutput = _check_reg_targets(
1087-
y_true, y_pred, multioutput
1096+
xp, _, device = get_namespace_and_device(y_true, y_pred, sample_weight, multioutput)
1097+
1098+
_, y_true, y_pred, sample_weight, multioutput = (
1099+
_check_reg_targets_with_floating_dtype(
1100+
y_true, y_pred, sample_weight, multioutput, xp=xp
1101+
)
10881102
)
1103+
10891104
check_consistent_length(y_true, y_pred, sample_weight)
10901105

1091-
y_diff_avg = np.average(y_true - y_pred, weights=sample_weight, axis=0)
1092-
numerator = np.average(
1106+
y_diff_avg = _average(y_true - y_pred, weights=sample_weight, axis=0)
1107+
numerator = _average(
10931108
(y_true - y_pred - y_diff_avg) ** 2, weights=sample_weight, axis=0
10941109
)
10951110

1096-
y_true_avg = np.average(y_true, weights=sample_weight, axis=0)
1097-
denominator = np.average((y_true - y_true_avg) ** 2, weights=sample_weight, axis=0)
1111+
y_true_avg = _average(y_true, weights=sample_weight, axis=0)
1112+
denominator = _average((y_true - y_true_avg) ** 2, weights=sample_weight, axis=0)
10981113

10991114
return _assemble_r2_explained_variance(
11001115
numerator=numerator,
11011116
denominator=denominator,
11021117
n_outputs=y_true.shape[1],
11031118
multioutput=multioutput,
11041119
force_finite=force_finite,
1105-
xp=get_namespace(y_true)[0],
1106-
# TODO: update once Array API support is added to explained_variance_score.
1107-
device=None,
1120+
xp=xp,
1121+
device=device,
11081122
)
11091123

11101124

sklearn/metrics/tests/test_common.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2084,10 +2084,18 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
20842084
check_array_api_regression_metric_multioutput,
20852085
],
20862086
cosine_similarity: [check_array_api_metric_pairwise],
2087+
explained_variance_score: [
2088+
check_array_api_regression_metric,
2089+
check_array_api_regression_metric_multioutput,
2090+
],
20872091
mean_absolute_error: [
20882092
check_array_api_regression_metric,
20892093
check_array_api_regression_metric_multioutput,
20902094
],
2095+
mean_pinball_loss: [
2096+
check_array_api_regression_metric,
2097+
check_array_api_regression_metric_multioutput,
2098+
],
20912099
mean_squared_error: [
20922100
check_array_api_regression_metric,
20932101
check_array_api_regression_metric_multioutput,

sklearn/metrics/tests/test_regression.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -566,7 +566,7 @@ def test_mean_pinball_loss_on_constant_predictions(distribution, target_quantile
566566
# Check that the loss of this constant predictor is greater or equal
567567
# than the loss of using the optimal quantile (up to machine
568568
# precision):
569-
assert pbl >= best_pbl - np.finfo(best_pbl.dtype).eps
569+
assert pbl >= best_pbl - np.finfo(np.float64).eps
570570

571571
# Check that the value of the pinball loss matches the analytical
572572
# formula.

0 commit comments

Comments (0)