5858def _check_reg_targets (y_true , y_pred , multioutput , dtype = "numeric" , xp = None ):
5959 """Check that y_true and y_pred belong to the same regression task.
6060
61+ To reduce redundancy when calling `_find_matching_floating_dtype`,
62+ please use `_check_reg_targets_with_floating_dtype` instead.
63+
6164 Parameters
6265 ----------
63- y_true : array-like
66+ y_true : array-like of shape (n_samples,) or (n_samples, n_outputs)
67+ Ground truth (correct) target values.
6468
65- y_pred : array-like
69+ y_pred : array-like of shape (n_samples,) or (n_samples, n_outputs)
70+ Estimated target values.
6671
6772 multioutput : array-like or string in ['raw_values', uniform_average',
6873 'variance_weighted'] or None
@@ -137,6 +142,71 @@ def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric", xp=None):
137142 return y_type , y_true , y_pred , multioutput
138143
139144
145+ def _check_reg_targets_with_floating_dtype (
146+ y_true , y_pred , sample_weight , multioutput , xp = None
147+ ):
148+ """Ensures that y_true, y_pred, and sample_weight correspond to the same
149+ regression task.
150+
151+ Extends `_check_reg_targets` by automatically selecting a suitable floating-point
152+ data type for inputs using `_find_matching_floating_dtype`.
153+
154+ Use this private method only when converting inputs to array API-compatibles.
155+
156+ Parameters
157+ ----------
158+ y_true : array-like of shape (n_samples,) or (n_samples, n_outputs)
159+ Ground truth (correct) target values.
160+
161+ y_pred : array-like of shape (n_samples,) or (n_samples, n_outputs)
162+ Estimated target values.
163+
164+ sample_weight : array-like of shape (n_samples,)
165+
166+ multioutput : array-like or string in ['raw_values', 'uniform_average', \
167+ 'variance_weighted'] or None
168+ None is accepted due to backward compatibility of r2_score().
169+
170+ xp : module, default=None
171+ Precomputed array namespace module. When passed, typically from a caller
172+ that has already performed inspection of its own inputs, skips array
173+ namespace inspection.
174+
175+ Returns
176+ -------
177+ type_true : one of {'continuous', 'continuous-multioutput'}
178+ The type of the true target data, as output by
179+ 'utils.multiclass.type_of_target'.
180+
181+ y_true : array-like of shape (n_samples, n_outputs)
182+ Ground truth (correct) target values.
183+
184+ y_pred : array-like of shape (n_samples, n_outputs)
185+ Estimated target values.
186+
187+ sample_weight : array-like of shape (n_samples,), default=None
188+ Sample weights.
189+
190+ multioutput : array-like of shape (n_outputs) or string in ['raw_values', \
191+ 'uniform_average', 'variance_weighted'] or None
192+ Custom output weights if ``multioutput`` is array-like or
193+ just the corresponding argument if ``multioutput`` is a
194+ correct keyword.
195+ """
196+ dtype_name = _find_matching_floating_dtype (y_true , y_pred , sample_weight , xp = xp )
197+
198+ y_type , y_true , y_pred , multioutput = _check_reg_targets (
199+ y_true , y_pred , multioutput , dtype = dtype_name , xp = xp
200+ )
201+
202+ # _check_reg_targets does not accept sample_weight as input.
203+ # Convert sample_weight's data type separately to match dtype_name.
204+ if sample_weight is not None :
205+ sample_weight = xp .asarray (sample_weight , dtype = dtype_name )
206+
207+ return y_type , y_true , y_pred , sample_weight , multioutput
208+
209+
140210@validate_params (
141211 {
142212 "y_true" : ["array-like" ],
@@ -201,14 +271,14 @@ def mean_absolute_error(
201271 >>> mean_absolute_error(y_true, y_pred, multioutput=[0.3, 0.7])
202272 0.85...
203273 """
204- input_arrays = [y_true , y_pred , sample_weight , multioutput ]
205- xp , _ = get_namespace (* input_arrays )
206-
207- dtype = _find_matching_floating_dtype (y_true , y_pred , sample_weight , xp = xp )
274+ xp , _ = get_namespace (y_true , y_pred , sample_weight , multioutput )
208275
209- _ , y_true , y_pred , multioutput = _check_reg_targets (
210- y_true , y_pred , multioutput , dtype = dtype , xp = xp
276+ _ , y_true , y_pred , sample_weight , multioutput = (
277+ _check_reg_targets_with_floating_dtype (
278+ y_true , y_pred , sample_weight , multioutput , xp = xp
279+ )
211280 )
281+
212282 check_consistent_length (y_true , y_pred , sample_weight )
213283
214284 output_errors = _average (
@@ -398,19 +468,16 @@ def mean_absolute_percentage_error(
398468 >>> mean_absolute_percentage_error(y_true, y_pred)
399469 112589990684262.48
400470 """
401- input_arrays = [y_true , y_pred , sample_weight , multioutput ]
402- xp , _ = get_namespace (* input_arrays )
403- dtype = _find_matching_floating_dtype (y_true , y_pred , sample_weight , xp = xp )
404-
405- y_type , y_true , y_pred , multioutput = _check_reg_targets (
406- y_true , y_pred , multioutput , dtype = dtype , xp = xp
471+ xp , _ = get_namespace (y_true , y_pred , sample_weight , multioutput )
472+ _ , y_true , y_pred , sample_weight , multioutput = (
473+ _check_reg_targets_with_floating_dtype (
474+ y_true , y_pred , sample_weight , multioutput , xp = xp
475+ )
407476 )
408477 check_consistent_length (y_true , y_pred , sample_weight )
409- epsilon = xp .asarray (xp .finfo (xp .float64 ).eps , dtype = dtype )
410- y_true_abs = xp .asarray (xp .abs (y_true ), dtype = dtype )
411- mape = xp .asarray (xp .abs (y_pred - y_true ), dtype = dtype ) / xp .maximum (
412- y_true_abs , epsilon
413- )
478+ epsilon = xp .asarray (xp .finfo (xp .float64 ).eps , dtype = y_true .dtype )
479+ y_true_abs = xp .abs (y_true )
480+ mape = xp .abs (y_pred - y_true ) / xp .maximum (y_true_abs , epsilon )
414481 output_errors = _average (mape , weights = sample_weight , axis = 0 )
415482 if isinstance (multioutput , str ):
416483 if multioutput == "raw_values" :
@@ -494,10 +561,10 @@ def mean_squared_error(
494561 0.825...
495562 """
496563 xp , _ = get_namespace (y_true , y_pred , sample_weight , multioutput )
497- dtype = _find_matching_floating_dtype ( y_true , y_pred , xp = xp )
498-
499- _ , y_true , y_pred , multioutput = _check_reg_targets (
500- y_true , y_pred , multioutput , dtype = dtype , xp = xp
564+ _ , y_true , y_pred , sample_weight , multioutput = (
565+ _check_reg_targets_with_floating_dtype (
566+ y_true , y_pred , sample_weight , multioutput , xp = xp
567+ )
501568 )
502569 check_consistent_length (y_true , y_pred , sample_weight )
503570 output_errors = _average ((y_true - y_pred ) ** 2 , axis = 0 , weights = sample_weight )
@@ -670,10 +737,9 @@ def mean_squared_log_error(
670737 0.060...
671738 """
672739 xp , _ = get_namespace (y_true , y_pred )
673- dtype = _find_matching_floating_dtype (y_true , y_pred , xp = xp )
674740
675- _ , y_true , y_pred , _ = _check_reg_targets (
676- y_true , y_pred , multioutput , dtype = dtype , xp = xp
741+ _ , y_true , y_pred , _ , _ = _check_reg_targets_with_floating_dtype (
742+ y_true , y_pred , sample_weight , multioutput , xp = xp
677743 )
678744
679745 if xp .any (y_true <= - 1 ) or xp .any (y_pred <= - 1 ):
@@ -747,10 +813,9 @@ def root_mean_squared_log_error(
747813 0.199...
748814 """
749815 xp , _ = get_namespace (y_true , y_pred )
750- dtype = _find_matching_floating_dtype (y_true , y_pred , xp = xp )
751816
752- _ , y_true , y_pred , multioutput = _check_reg_targets (
753- y_true , y_pred , multioutput , dtype = dtype , xp = xp
817+ _ , y_true , y_pred , _ , _ = _check_reg_targets_with_floating_dtype (
818+ y_true , y_pred , sample_weight , multioutput , xp = xp
754819 )
755820
756821 if xp .any (y_true <= - 1 ) or xp .any (y_pred <= - 1 ):
@@ -1188,11 +1253,12 @@ def r2_score(
11881253 y_true , y_pred , sample_weight , multioutput
11891254 )
11901255
1191- dtype = _find_matching_floating_dtype ( y_true , y_pred , sample_weight , xp = xp )
1192-
1193- _ , y_true , y_pred , multioutput = _check_reg_targets (
1194- y_true , y_pred , multioutput , dtype = dtype , xp = xp
1256+ _ , y_true , y_pred , sample_weight , multioutput = (
1257+ _check_reg_targets_with_floating_dtype (
1258+ y_true , y_pred , sample_weight , multioutput , xp = xp
1259+ )
11951260 )
1261+
11961262 check_consistent_length (y_true , y_pred , sample_weight )
11971263
11981264 if _num_samples (y_pred ) < 2 :
@@ -1201,7 +1267,7 @@ def r2_score(
12011267 return float ("nan" )
12021268
12031269 if sample_weight is not None :
1204- sample_weight = column_or_1d (sample_weight , dtype = dtype )
1270+ sample_weight = column_or_1d (sample_weight )
12051271 weight = sample_weight [:, None ]
12061272 else :
12071273 weight = 1.0
@@ -1356,8 +1422,8 @@ def mean_tweedie_deviance(y_true, y_pred, *, sample_weight=None, power=0):
13561422 1.4260...
13571423 """
13581424 xp , _ = get_namespace (y_true , y_pred )
1359- y_type , y_true , y_pred , _ = _check_reg_targets (
1360- y_true , y_pred , None , dtype = [ xp . float64 , xp . float32 ] , xp = xp
1425+ y_type , y_true , y_pred , sample_weight , _ = _check_reg_targets_with_floating_dtype (
1426+ y_true , y_pred , sample_weight , multioutput = None , xp = xp
13611427 )
13621428 if y_type == "continuous-multioutput" :
13631429 raise ValueError ("Multioutput not supported in mean_tweedie_deviance" )
@@ -1570,8 +1636,8 @@ def d2_tweedie_score(y_true, y_pred, *, sample_weight=None, power=0):
15701636 """
15711637 xp , _ = get_namespace (y_true , y_pred )
15721638
1573- y_type , y_true , y_pred , _ = _check_reg_targets (
1574- y_true , y_pred , None , dtype = [ xp . float64 , xp . float32 ] , xp = xp
1639+ y_type , y_true , y_pred , sample_weight , _ = _check_reg_targets_with_floating_dtype (
1640+ y_true , y_pred , sample_weight , multioutput = None , xp = xp
15751641 )
15761642 if y_type == "continuous-multioutput" :
15771643 raise ValueError ("Multioutput not supported in d2_tweedie_score" )
0 commit comments