Skip to content

Commit fcf2085

Browse files
committed
review
1 parent 42e5f58 commit fcf2085

File tree

1 file changed

+12
-16
lines changed

1 file changed

+12
-16
lines changed

sklearn/metrics/tests/test_common.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,8 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
618618
"root_mean_squared_log_error",
619619
}
620620

621-
# Metrics that support mixed array API inputs
621+
# Metrics that support mixed namespace/device array API inputs
622+
# Mixed mixed namespace/device support is NOT planned for pairwise metrics
622623
METRICS_SUPPORTING_MIXED_NAMESPACE = [
623624
"average_precision_score",
624625
"brier_score_loss",
@@ -2531,9 +2532,10 @@ def test_mixed_array_api_namespace_input_compliance(
25312532

25322533
data_all = {
25332534
"binary": ([0, 0, 1, 1], [0, 1, 0, 1]),
2534-
"continuous_binary": ([1, 0, 1, 0], [0.5, 0.2, 0.7, 0.6]),
2535-
"continuous_label_indicator": ([[1, 0, 1, 0]], [[0.5, 0.2, 0.7, 0.6]]),
2536-
"regression": ([2, 1, 3, 4], [2, 1, 2, 2]),
2535+
"binary_continuous": ([1, 0, 1, 0], [0.5, 0.2, 0.7, 0.6]),
2536+
"label_indicator_continuous": ([[1, 0, 1, 0]], [[0.5, 0.2, 0.7, 0.6]]),
2537+
"regression_integer": ([2, 1, 3, 4], [2, 1, 2, 2]),
2538+
"regression_continuous": ([2.1, 1.0, 3.0, 4.0], [2.2, 1.1, 2.0, 2.0]),
25372539
}
25382540
sample_weight = [1, 1, 2, 2]
25392541

@@ -2546,28 +2548,22 @@ def _get_dtype(data, xp, device):
25462548
dtype = xp.int64
25472549
return dtype
25482550

2549-
checks = ["default"]
25502551
if metric_name in CLASSIFICATION_METRICS:
25512552
# These should all accept binary label input as there are no
25522553
# `CLASSIFICATION_METRICS` that are in `METRIC_UNDEFINED_BINARY` and are
25532554
# NOT `partial`s (which we do not test for in array API compliance)
2554-
data = data_all["binary"]
2555+
data_cases = ["binary"]
25552556
elif metric_name in {**CONTINUOUS_CLASSIFICATION_METRICS, **CURVE_METRICS}:
25562557
if metric_name not in METRIC_UNDEFINED_BINARY:
2557-
data = data_all["continuous_binary"]
2558+
data_cases = ["binary_continuous"]
25582559
else:
2559-
data = data_all["continuous_label_indicator"]
2560+
data_cases = ["label_indicator_continuous"]
25602561
elif metric_name in REGRESSION_METRICS:
2561-
data = data_all["regression"]
2562-
checks.append("float")
2562+
data_cases = ["regression_integer", "regression_continuous"]
25632563

25642564
with config_context(array_api_dispatch=True):
2565-
y1, y2 = data
2566-
for check in checks:
2567-
if check == "float":
2568-
# Convert regression inputs from int to float
2569-
y1 = np.array(y1) * 0.3
2570-
y2 = np.array(y2) * 0.3
2565+
for data_case in data_cases:
2566+
y1, y2 = data_all[data_case]
25712567

25722568
dtype = _get_dtype(y1, xp_from, from_ns_and_device.device)
25732569
y1_xp = xp_from.asarray(y1, device=from_ns_and_device.device, dtype=dtype)

0 commit comments

Comments
 (0)