Skip to content

Commit 487071f

Browse files
Fix scorer representation test for newer version of scikit-learn
1 parent 4a1b107 commit 487071f

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

tests/utils/test_score.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import numpy as np
2+
import sklearn
23
from numpy.typing import NDArray
4+
from packaging import version
35

46
from pydvl.utils.score import Scorer, compose_score, squashed_r2, squashed_variance
57

@@ -24,7 +26,13 @@ def test_scorer():
2426
"""Tests the Scorer class."""
2527
scorer = Scorer("r2")
2628
assert str(scorer) == "r2"
27-
assert repr(scorer) == "R2 (scorer=make_scorer(r2_score))"
29+
if version.parse(sklearn.__version__) >= version.parse("1.4.0"):
30+
assert (
31+
repr(scorer)
32+
== "R2 (scorer=make_scorer(r2_score, response_method='predict'))"
33+
)
34+
else:
35+
assert repr(scorer) == "R2 (scorer=make_scorer(r2_score))"
2836

2937
coef = np.array([1, 2])
3038
X = np.array([[1, 2], [3, 4]])

tests/value/shapley/test_classwise.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import numpy as np
44
import pandas as pd
55
import pytest
6+
import sklearn
67
from numpy.typing import NDArray
8+
from packaging import version
79

810
from pydvl.utils import Dataset, Utility, powerset
911
from pydvl.value import MaxChecks, ValuationResult
@@ -165,7 +167,13 @@ def test_classwise_scorer_representation():
165167

166168
scorer = ClasswiseScorer("accuracy", initial_label=0)
167169
assert str(scorer) == "classwise accuracy"
168-
assert repr(scorer) == "ClasswiseAccuracy (scorer=make_scorer(accuracy_score))"
170+
if version.parse(sklearn.__version__) >= version.parse("1.4.0"):
171+
assert (
172+
repr(scorer)
173+
== "ClasswiseAccuracy (scorer=make_scorer(accuracy_score, response_method='predict'))"
174+
)
175+
else:
176+
assert repr(scorer) == "ClasswiseAccuracy (scorer=make_scorer(accuracy_score))"
169177

170178

171179
@pytest.mark.parametrize("n_element, left_margin, right_margin", [(101, 0.3, 0.4)])

0 commit comments

Comments
 (0)