@@ -51,7 +51,7 @@ def input_spec(self) -> lit_types.Spec:
5151 def output_spec (self ) -> lit_types .Spec :
5252 return {
5353 'pred' : lit_types .MulticlassPreds (vocab = COLORS , parent = 'label' ),
54- 'aux_pred' : lit_types .MulticlassPreds (vocab = COLORS , parent = 'label' )
54+ 'aux_pred' : lit_types .MulticlassPreds (vocab = COLORS , parent = 'label' ),
5555 }
5656
5757 def predict_minibatch (
@@ -64,10 +64,9 @@ def predict_example(ex: lit_types.JsonDict) -> tuple[float, float, float]:
6464 return TEST_DATA [x ].prediction
6565
6666 for example in inputs :
67- output .append ({
68- 'pred' : predict_example (example ),
69- 'aux_pred' : [1 / 3 , 1 / 3 , 1 / 3 ]
70- })
67+ output .append (
68+ {'pred' : predict_example (example ), 'aux_pred' : [1 / 3 , 1 / 3 , 1 / 3 ]}
69+ )
7170 return output
7271
7372
@@ -148,6 +147,43 @@ def test_model_output_is_missing_in_config(self):
148147 config = {'Label' : 'red' },
149148 )
150149
150+ @parameterized .named_parameters (
151+ dict (
152+ testcase_name = 'red' ,
153+ label = 'red' ,
154+ exp_roc = [(0.0 , 0.0 ), (0.0 , 0.5 ), (1.0 , 0.5 ), (1.0 , 1.0 )],
155+ exp_pr = [(0.5 , 0.5 ), (2 / 3 , 1.0 ), (1.0 , 0.5 ), (1.0 , 0.0 )],
156+ ),
157+ dict (
158+ testcase_name = 'blue' ,
159+ label = 'blue' ,
160+ exp_roc = [(0.0 , 0.0 ), (0.0 , 1.0 ), (1.0 , 1.0 )],
161+ exp_pr = [
162+ (0.3333333333333333 , 1.0 ),
163+ (0.5 , 1.0 ),
164+ (1.0 , 1.0 ),
165+ (1.0 , 0.0 ),
166+ ],
167+ ),
168+ )
169+ def test_interpreter_honors_user_selected_label (
170+ self , label : str , exp_roc : _Curve , exp_pr : _Curve
171+ ):
172+ """Tests a happy scenario when a user doesn't specify the class label."""
173+ curves_data = self .ci .run (
174+ inputs = self .dataset .examples ,
175+ model = self .model ,
176+ dataset = self .dataset ,
177+ config = {
178+ curves .TARGET_LABEL_KEY : label ,
179+ curves .TARGET_PREDICTION_KEY : 'pred' ,
180+ },
181+ )
182+ self .assertIn (curves .ROC_DATA , curves_data )
183+ self .assertIn (curves .PR_DATA , curves_data )
184+ self .assertEqual (curves_data [curves .ROC_DATA ], exp_roc )
185+ self .assertEqual (curves_data [curves .PR_DATA ], exp_pr )
186+
151187 def test_config_spec (self ):
152188 """Tests that the interpreter config has correct fields of correct type."""
153189 spec = self .ci .config_spec ()
0 commit comments