Skip to content

Commit 9aa94ee

Browse files
author
Googler
committed
Migrate lit_nlp to sklearn v1.6.1
PiperOrigin-RevId: 740720447
1 parent 21466fa commit 9aa94ee

File tree

3 files changed

+43
-7
lines changed

3 files changed

+43
-7
lines changed

lit_nlp/components/curves_test.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def input_spec(self) -> lit_types.Spec:
5151
def output_spec(self) -> lit_types.Spec:
5252
return {
5353
'pred': lit_types.MulticlassPreds(vocab=COLORS, parent='label'),
54-
'aux_pred': lit_types.MulticlassPreds(vocab=COLORS, parent='label')
54+
'aux_pred': lit_types.MulticlassPreds(vocab=COLORS, parent='label'),
5555
}
5656

5757
def predict_minibatch(
@@ -64,10 +64,9 @@ def predict_example(ex: lit_types.JsonDict) -> tuple[float, float, float]:
6464
return TEST_DATA[x].prediction
6565

6666
for example in inputs:
67-
output.append({
68-
'pred': predict_example(example),
69-
'aux_pred': [1 / 3, 1 / 3, 1 / 3]
70-
})
67+
output.append(
68+
{'pred': predict_example(example), 'aux_pred': [1 / 3, 1 / 3, 1 / 3]}
69+
)
7170
return output
7271

7372

@@ -148,6 +147,43 @@ def test_model_output_is_missing_in_config(self):
148147
config={'Label': 'red'},
149148
)
150149

150+
@parameterized.named_parameters(
151+
dict(
152+
testcase_name='red',
153+
label='red',
154+
exp_roc=[(0.0, 0.0), (0.0, 0.5), (1.0, 0.5), (1.0, 1.0)],
155+
exp_pr=[(0.5, 0.5), (2 / 3, 1.0), (1.0, 0.5), (1.0, 0.0)],
156+
),
157+
dict(
158+
testcase_name='blue',
159+
label='blue',
160+
exp_roc=[(0.0, 0.0), (0.0, 1.0), (1.0, 1.0)],
161+
exp_pr=[
162+
(0.3333333333333333, 1.0),
163+
(0.5, 1.0),
164+
(1.0, 1.0),
165+
(1.0, 0.0),
166+
],
167+
),
168+
)
169+
def test_interpreter_honors_user_selected_label(
170+
self, label: str, exp_roc: _Curve, exp_pr: _Curve
171+
):
172+
"""Tests a happy scenario when a user doesn't specify the class label."""
173+
curves_data = self.ci.run(
174+
inputs=self.dataset.examples,
175+
model=self.model,
176+
dataset=self.dataset,
177+
config={
178+
curves.TARGET_LABEL_KEY: label,
179+
curves.TARGET_PREDICTION_KEY: 'pred',
180+
},
181+
)
182+
self.assertIn(curves.ROC_DATA, curves_data)
183+
self.assertIn(curves.PR_DATA, curves_data)
184+
self.assertEqual(curves_data[curves.ROC_DATA], exp_roc)
185+
self.assertEqual(curves_data[curves.PR_DATA], exp_pr)
186+
151187
def test_config_spec(self):
152188
"""Tests that the interpreter config has correct fields of correct type."""
153189
spec = self.ci.config_spec()

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ dependencies = [
3434
"rouge-score>=0.1.2",
3535
"sacrebleu>=2.3.1",
3636
"saliency>=0.1.3",
37-
"scikit-learn>=1.0.2",
37+
"scikit-learn>=1.6.1",
3838
"scipy>=1.10.1",
3939
"shap>=0.42.0,<0.46.0",
4040
"six>=1.16.0",

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ requests>=2.31.0
3131
rouge-score>=0.1.2
3232
sacrebleu>=2.3.1
3333
saliency>=0.1.3
34-
scikit-learn>=1.0.2
34+
scikit-learn>=1.6.1
3535
scipy>=1.10.1
3636
shap>=0.42.0,<0.46.0
3737
six>=1.16.0

0 commit comments

Comments
 (0)