@@ -288,7 +288,7 @@ def mean_absolute_error(
288288 if multioutput == "raw_values" :
289289 return output_errors
290290 elif multioutput == "uniform_average" :
291- # pass None as weights to np.average : uniform mean
291+ # pass None as weights to _average : uniform mean
292292 multioutput = None
293293
294294 # Average across the outputs (if needed).
@@ -360,35 +360,45 @@ def mean_pinball_loss(
360360 >>> from sklearn.metrics import mean_pinball_loss
361361 >>> y_true = [1, 2, 3]
362362 >>> mean_pinball_loss(y_true, [0, 2, 3], alpha=0.1)
363- np.float64( 0.03...)
363+ 0.03...
364364 >>> mean_pinball_loss(y_true, [1, 2, 4], alpha=0.1)
365- np.float64( 0.3...)
365+ 0.3...
366366 >>> mean_pinball_loss(y_true, [0, 2, 3], alpha=0.9)
367- np.float64( 0.3...)
367+ 0.3...
368368 >>> mean_pinball_loss(y_true, [1, 2, 4], alpha=0.9)
369- np.float64( 0.03...)
369+ 0.03...
370370 >>> mean_pinball_loss(y_true, y_true, alpha=0.1)
371- np.float64( 0.0)
371+ 0.0
372372 >>> mean_pinball_loss(y_true, y_true, alpha=0.9)
373- np.float64( 0.0)
373+ 0.0
374374 """
375- y_type , y_true , y_pred , multioutput = _check_reg_targets (
376- y_true , y_pred , multioutput
375+ xp , _ = get_namespace (y_true , y_pred , sample_weight , multioutput )
376+
377+ _ , y_true , y_pred , sample_weight , multioutput = (
378+ _check_reg_targets_with_floating_dtype (
379+ y_true , y_pred , sample_weight , multioutput , xp = xp
380+ )
377381 )
382+
378383 check_consistent_length (y_true , y_pred , sample_weight )
379384 diff = y_true - y_pred
380- sign = (diff >= 0 ). astype ( diff .dtype )
385+ sign = xp . astype (diff >= 0 , diff .dtype )
381386 loss = alpha * sign * diff - (1 - alpha ) * (1 - sign ) * diff
382- output_errors = np . average (loss , weights = sample_weight , axis = 0 )
387+ output_errors = _average (loss , weights = sample_weight , axis = 0 )
383388
384389 if isinstance (multioutput , str ) and multioutput == "raw_values" :
385390 return output_errors
386391
387392 if isinstance (multioutput , str ) and multioutput == "uniform_average" :
388- # pass None as weights to np.average : uniform mean
393+ # pass None as weights to _average : uniform mean
389394 multioutput = None
390395
391- return np .average (output_errors , weights = multioutput )
396+ # Average across the outputs (if needed).
397+ # The second call to `_average` should always return
398+ # a scalar array that we convert to a Python float to
399+ # consistently return the same eager evaluated value.
400+ # Therefore, `axis=None`.
401+ return float (_average (output_errors , weights = multioutput ))
392402
393403
394404@validate_params (
@@ -949,12 +959,12 @@ def _assemble_r2_explained_variance(
949959 # return scores individually
950960 return output_scores
951961 elif multioutput == "uniform_average" :
952- # Passing None as weights to np.average results is uniform mean
962+ # pass None as weights to _average: uniform mean
953963 avg_weights = None
954964 elif multioutput == "variance_weighted" :
955965 avg_weights = denominator
956966 if not xp .any (nonzero_denominator ):
957- # All weights are zero, np.average would raise a ZeroDiv error.
967+ # All weights are zero, _average would raise a ZeroDiv error.
958968 # This only happens when all y are constant (or 1-element long)
959969 # Since weights are all equal, fall back to uniform weights.
960970 avg_weights = None
@@ -1083,28 +1093,32 @@ def explained_variance_score(
10831093 >>> explained_variance_score(y_true, y_pred, force_finite=False)
10841094 -inf
10851095 """
1086- y_type , y_true , y_pred , multioutput = _check_reg_targets (
1087- y_true , y_pred , multioutput
1096+ xp , _ , device = get_namespace_and_device (y_true , y_pred , sample_weight , multioutput )
1097+
1098+ _ , y_true , y_pred , sample_weight , multioutput = (
1099+ _check_reg_targets_with_floating_dtype (
1100+ y_true , y_pred , sample_weight , multioutput , xp = xp
1101+ )
10881102 )
1103+
10891104 check_consistent_length (y_true , y_pred , sample_weight )
10901105
1091- y_diff_avg = np . average (y_true - y_pred , weights = sample_weight , axis = 0 )
1092- numerator = np . average (
1106+ y_diff_avg = _average (y_true - y_pred , weights = sample_weight , axis = 0 )
1107+ numerator = _average (
10931108 (y_true - y_pred - y_diff_avg ) ** 2 , weights = sample_weight , axis = 0
10941109 )
10951110
1096- y_true_avg = np . average (y_true , weights = sample_weight , axis = 0 )
1097- denominator = np . average ((y_true - y_true_avg ) ** 2 , weights = sample_weight , axis = 0 )
1111+ y_true_avg = _average (y_true , weights = sample_weight , axis = 0 )
1112+ denominator = _average ((y_true - y_true_avg ) ** 2 , weights = sample_weight , axis = 0 )
10981113
10991114 return _assemble_r2_explained_variance (
11001115 numerator = numerator ,
11011116 denominator = denominator ,
11021117 n_outputs = y_true .shape [1 ],
11031118 multioutput = multioutput ,
11041119 force_finite = force_finite ,
1105- xp = get_namespace (y_true )[0 ],
1106- # TODO: update once Array API support is added to explained_variance_score.
1107- device = None ,
1120+ xp = xp ,
1121+ device = device ,
11081122 )
11091123
11101124
0 commit comments