Skip to content

Commit 33f08f1

Browse files
virchan, adrinjalali, ogrisel
authored
ENH Reduce redundancy in floating type checks for Array API support in _regression.py (scikit-learn#30128)
Co-authored-by: Adrin Jalali <[email protected]> Co-authored-by: Olivier Grisel <[email protected]>
1 parent f28f5f9 commit 33f08f1

File tree

2 files changed

+106
-40
lines changed

2 files changed

+106
-40
lines changed

sklearn/metrics/_regression.py

Lines changed: 104 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,16 @@
5858
def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric", xp=None):
5959
"""Check that y_true and y_pred belong to the same regression task.
6060
61+
To reduce redundancy when calling `_find_matching_floating_dtype`,
62+
please use `_check_reg_targets_with_floating_dtype` instead.
63+
6164
Parameters
6265
----------
63-
y_true : array-like
66+
y_true : array-like of shape (n_samples,) or (n_samples, n_outputs)
67+
Ground truth (correct) target values.
6468
65-
y_pred : array-like
69+
y_pred : array-like of shape (n_samples,) or (n_samples, n_outputs)
70+
Estimated target values.
6671
6772
multioutput : array-like or string in ['raw_values', 'uniform_average',
6873
'variance_weighted'] or None
@@ -137,6 +142,71 @@ def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric", xp=None):
137142
return y_type, y_true, y_pred, multioutput
138143

139144

145+
def _check_reg_targets_with_floating_dtype(
146+
y_true, y_pred, sample_weight, multioutput, xp=None
147+
):
148+
"""Ensures that y_true, y_pred, and sample_weight correspond to the same
149+
regression task.
150+
151+
Extends `_check_reg_targets` by automatically selecting a suitable floating-point
152+
data type for inputs using `_find_matching_floating_dtype`.
153+
154+
Use this private function only when the inputs must be converted to array
API-compatible arrays.
155+
156+
Parameters
157+
----------
158+
y_true : array-like of shape (n_samples,) or (n_samples, n_outputs)
159+
Ground truth (correct) target values.
160+
161+
y_pred : array-like of shape (n_samples,) or (n_samples, n_outputs)
162+
Estimated target values.
163+
164+
sample_weight : array-like of shape (n_samples,) or None
165+
166+
multioutput : array-like or string in ['raw_values', 'uniform_average', \
167+
'variance_weighted'] or None
168+
None is accepted due to backward compatibility of r2_score().
169+
170+
xp : module, default=None
171+
Precomputed array namespace module. When passed, typically from a caller
172+
that has already performed inspection of its own inputs, skips array
173+
namespace inspection.
174+
175+
Returns
176+
-------
177+
type_true : one of {'continuous', 'continuous-multioutput'}
178+
The type of the true target data, as output by
179+
'utils.multiclass.type_of_target'.
180+
181+
y_true : array-like of shape (n_samples, n_outputs)
182+
Ground truth (correct) target values.
183+
184+
y_pred : array-like of shape (n_samples, n_outputs)
185+
Estimated target values.
186+
187+
sample_weight : array-like of shape (n_samples,) or None
187+
Sample weights cast to the matched floating dtype; None if not provided.
189+
190+
multioutput : array-like of shape (n_outputs,) or string in ['raw_values', \
191+
'uniform_average', 'variance_weighted'] or None
192+
Custom output weights if ``multioutput`` is array-like or
193+
just the corresponding argument if ``multioutput`` is a
194+
correct keyword.
195+
"""
196+
dtype_name = _find_matching_floating_dtype(y_true, y_pred, sample_weight, xp=xp)
197+
198+
y_type, y_true, y_pred, multioutput = _check_reg_targets(
199+
y_true, y_pred, multioutput, dtype=dtype_name, xp=xp
200+
)
201+
202+
# _check_reg_targets does not accept sample_weight as input.
203+
# Convert sample_weight's data type separately to match dtype_name.
204+
if sample_weight is not None:
205+
sample_weight = xp.asarray(sample_weight, dtype=dtype_name)
206+
207+
return y_type, y_true, y_pred, sample_weight, multioutput
208+
209+
140210
@validate_params(
141211
{
142212
"y_true": ["array-like"],
@@ -201,14 +271,14 @@ def mean_absolute_error(
201271
>>> mean_absolute_error(y_true, y_pred, multioutput=[0.3, 0.7])
202272
0.85...
203273
"""
204-
input_arrays = [y_true, y_pred, sample_weight, multioutput]
205-
xp, _ = get_namespace(*input_arrays)
206-
207-
dtype = _find_matching_floating_dtype(y_true, y_pred, sample_weight, xp=xp)
274+
xp, _ = get_namespace(y_true, y_pred, sample_weight, multioutput)
208275

209-
_, y_true, y_pred, multioutput = _check_reg_targets(
210-
y_true, y_pred, multioutput, dtype=dtype, xp=xp
276+
_, y_true, y_pred, sample_weight, multioutput = (
277+
_check_reg_targets_with_floating_dtype(
278+
y_true, y_pred, sample_weight, multioutput, xp=xp
279+
)
211280
)
281+
212282
check_consistent_length(y_true, y_pred, sample_weight)
213283

214284
output_errors = _average(
@@ -398,19 +468,16 @@ def mean_absolute_percentage_error(
398468
>>> mean_absolute_percentage_error(y_true, y_pred)
399469
112589990684262.48
400470
"""
401-
input_arrays = [y_true, y_pred, sample_weight, multioutput]
402-
xp, _ = get_namespace(*input_arrays)
403-
dtype = _find_matching_floating_dtype(y_true, y_pred, sample_weight, xp=xp)
404-
405-
y_type, y_true, y_pred, multioutput = _check_reg_targets(
406-
y_true, y_pred, multioutput, dtype=dtype, xp=xp
471+
xp, _ = get_namespace(y_true, y_pred, sample_weight, multioutput)
472+
_, y_true, y_pred, sample_weight, multioutput = (
473+
_check_reg_targets_with_floating_dtype(
474+
y_true, y_pred, sample_weight, multioutput, xp=xp
475+
)
407476
)
408477
check_consistent_length(y_true, y_pred, sample_weight)
409-
epsilon = xp.asarray(xp.finfo(xp.float64).eps, dtype=dtype)
410-
y_true_abs = xp.asarray(xp.abs(y_true), dtype=dtype)
411-
mape = xp.asarray(xp.abs(y_pred - y_true), dtype=dtype) / xp.maximum(
412-
y_true_abs, epsilon
413-
)
478+
epsilon = xp.asarray(xp.finfo(xp.float64).eps, dtype=y_true.dtype)
479+
y_true_abs = xp.abs(y_true)
480+
mape = xp.abs(y_pred - y_true) / xp.maximum(y_true_abs, epsilon)
414481
output_errors = _average(mape, weights=sample_weight, axis=0)
415482
if isinstance(multioutput, str):
416483
if multioutput == "raw_values":
@@ -494,10 +561,10 @@ def mean_squared_error(
494561
0.825...
495562
"""
496563
xp, _ = get_namespace(y_true, y_pred, sample_weight, multioutput)
497-
dtype = _find_matching_floating_dtype(y_true, y_pred, xp=xp)
498-
499-
_, y_true, y_pred, multioutput = _check_reg_targets(
500-
y_true, y_pred, multioutput, dtype=dtype, xp=xp
564+
_, y_true, y_pred, sample_weight, multioutput = (
565+
_check_reg_targets_with_floating_dtype(
566+
y_true, y_pred, sample_weight, multioutput, xp=xp
567+
)
501568
)
502569
check_consistent_length(y_true, y_pred, sample_weight)
503570
output_errors = _average((y_true - y_pred) ** 2, axis=0, weights=sample_weight)
@@ -670,10 +737,9 @@ def mean_squared_log_error(
670737
0.060...
671738
"""
672739
xp, _ = get_namespace(y_true, y_pred)
673-
dtype = _find_matching_floating_dtype(y_true, y_pred, xp=xp)
674740

675-
_, y_true, y_pred, _ = _check_reg_targets(
676-
y_true, y_pred, multioutput, dtype=dtype, xp=xp
741+
_, y_true, y_pred, _, _ = _check_reg_targets_with_floating_dtype(
742+
y_true, y_pred, sample_weight, multioutput, xp=xp
677743
)
678744

679745
if xp.any(y_true <= -1) or xp.any(y_pred <= -1):
@@ -747,10 +813,9 @@ def root_mean_squared_log_error(
747813
0.199...
748814
"""
749815
xp, _ = get_namespace(y_true, y_pred)
750-
dtype = _find_matching_floating_dtype(y_true, y_pred, xp=xp)
751816

752-
_, y_true, y_pred, multioutput = _check_reg_targets(
753-
y_true, y_pred, multioutput, dtype=dtype, xp=xp
817+
_, y_true, y_pred, _, _ = _check_reg_targets_with_floating_dtype(
818+
y_true, y_pred, sample_weight, multioutput, xp=xp
754819
)
755820

756821
if xp.any(y_true <= -1) or xp.any(y_pred <= -1):
@@ -1188,11 +1253,12 @@ def r2_score(
11881253
y_true, y_pred, sample_weight, multioutput
11891254
)
11901255

1191-
dtype = _find_matching_floating_dtype(y_true, y_pred, sample_weight, xp=xp)
1192-
1193-
_, y_true, y_pred, multioutput = _check_reg_targets(
1194-
y_true, y_pred, multioutput, dtype=dtype, xp=xp
1256+
_, y_true, y_pred, sample_weight, multioutput = (
1257+
_check_reg_targets_with_floating_dtype(
1258+
y_true, y_pred, sample_weight, multioutput, xp=xp
1259+
)
11951260
)
1261+
11961262
check_consistent_length(y_true, y_pred, sample_weight)
11971263

11981264
if _num_samples(y_pred) < 2:
@@ -1201,7 +1267,7 @@ def r2_score(
12011267
return float("nan")
12021268

12031269
if sample_weight is not None:
1204-
sample_weight = column_or_1d(sample_weight, dtype=dtype)
1270+
sample_weight = column_or_1d(sample_weight)
12051271
weight = sample_weight[:, None]
12061272
else:
12071273
weight = 1.0
@@ -1356,8 +1422,8 @@ def mean_tweedie_deviance(y_true, y_pred, *, sample_weight=None, power=0):
13561422
1.4260...
13571423
"""
13581424
xp, _ = get_namespace(y_true, y_pred)
1359-
y_type, y_true, y_pred, _ = _check_reg_targets(
1360-
y_true, y_pred, None, dtype=[xp.float64, xp.float32], xp=xp
1425+
y_type, y_true, y_pred, sample_weight, _ = _check_reg_targets_with_floating_dtype(
1426+
y_true, y_pred, sample_weight, multioutput=None, xp=xp
13611427
)
13621428
if y_type == "continuous-multioutput":
13631429
raise ValueError("Multioutput not supported in mean_tweedie_deviance")
@@ -1570,8 +1636,8 @@ def d2_tweedie_score(y_true, y_pred, *, sample_weight=None, power=0):
15701636
"""
15711637
xp, _ = get_namespace(y_true, y_pred)
15721638

1573-
y_type, y_true, y_pred, _ = _check_reg_targets(
1574-
y_true, y_pred, None, dtype=[xp.float64, xp.float32], xp=xp
1639+
y_type, y_true, y_pred, sample_weight, _ = _check_reg_targets_with_floating_dtype(
1640+
y_true, y_pred, sample_weight, multioutput=None, xp=xp
15751641
)
15761642
if y_type == "continuous-multioutput":
15771643
raise ValueError("Multioutput not supported in d2_tweedie_score")

sklearn/metrics/tests/test_common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -583,8 +583,8 @@ def _require_positive_targets(y1, y2):
583583
def _require_log1p_targets(y1, y2):
584584
"""Make targets strictly larger than -1"""
585585
offset = abs(min(y1.min(), y2.min())) - 0.99
586-
y1 = y1.astype(float)
587-
y2 = y2.astype(float)
586+
y1 = y1.astype(np.float64)
587+
y2 = y2.astype(np.float64)
588588
y1 += offset
589589
y2 += offset
590590
return y1, y2

0 commit comments

Comments
 (0)