@@ -618,7 +618,8 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
618618 "root_mean_squared_log_error" ,
619619}
620620
621- # Metrics that support mixed array API inputs
621+ # Metrics that support mixed namespace/device array API inputs
622+ # Mixed mixed namespace/device support is NOT planned for pairwise metrics
622623METRICS_SUPPORTING_MIXED_NAMESPACE = [
623624 "average_precision_score" ,
624625 "brier_score_loss" ,
@@ -2531,9 +2532,10 @@ def test_mixed_array_api_namespace_input_compliance(
25312532
25322533 data_all = {
25332534 "binary" : ([0 , 0 , 1 , 1 ], [0 , 1 , 0 , 1 ]),
2534- "continuous_binary" : ([1 , 0 , 1 , 0 ], [0.5 , 0.2 , 0.7 , 0.6 ]),
2535- "continuous_label_indicator" : ([[1 , 0 , 1 , 0 ]], [[0.5 , 0.2 , 0.7 , 0.6 ]]),
2536- "regression" : ([2 , 1 , 3 , 4 ], [2 , 1 , 2 , 2 ]),
2535+ "binary_continuous" : ([1 , 0 , 1 , 0 ], [0.5 , 0.2 , 0.7 , 0.6 ]),
2536+ "label_indicator_continuous" : ([[1 , 0 , 1 , 0 ]], [[0.5 , 0.2 , 0.7 , 0.6 ]]),
2537+ "regression_integer" : ([2 , 1 , 3 , 4 ], [2 , 1 , 2 , 2 ]),
2538+ "regression_continuous" : ([2.1 , 1.0 , 3.0 , 4.0 ], [2.2 , 1.1 , 2.0 , 2.0 ]),
25372539 }
25382540 sample_weight = [1 , 1 , 2 , 2 ]
25392541
@@ -2546,28 +2548,22 @@ def _get_dtype(data, xp, device):
25462548 dtype = xp .int64
25472549 return dtype
25482550
2549- checks = ["default" ]
25502551 if metric_name in CLASSIFICATION_METRICS :
25512552 # These should all accept binary label input as there are no
25522553 # `CLASSIFICATION_METRICS` that are in `METRIC_UNDEFINED_BINARY` and are
25532554 # NOT `partial`s (which we do not test for in array API compliance)
2554- data = data_all ["binary" ]
2555+ data_cases = ["binary" ]
25552556 elif metric_name in {** CONTINUOUS_CLASSIFICATION_METRICS , ** CURVE_METRICS }:
25562557 if metric_name not in METRIC_UNDEFINED_BINARY :
2557- data = data_all [ "continuous_binary " ]
2558+ data_cases = [ "binary_continuous " ]
25582559 else :
2559- data = data_all [ "continuous_label_indicator " ]
2560+ data_cases = [ "label_indicator_continuous " ]
25602561 elif metric_name in REGRESSION_METRICS :
2561- data = data_all ["regression" ]
2562- checks .append ("float" )
2562+ data_cases = ["regression_integer" , "regression_continuous" ]
25632563
25642564 with config_context (array_api_dispatch = True ):
2565- y1 , y2 = data
2566- for check in checks :
2567- if check == "float" :
2568- # Convert regression inputs from int to float
2569- y1 = np .array (y1 ) * 0.3
2570- y2 = np .array (y2 ) * 0.3
2565+ for data_case in data_cases :
2566+ y1 , y2 = data_all [data_case ]
25712567
25722568 dtype = _get_dtype (y1 , xp_from , from_ns_and_device .device )
25732569 y1_xp = xp_from .asarray (y1 , device = from_ns_and_device .device , dtype = dtype )
0 commit comments