@@ -42,6 +42,7 @@ def test_ranking_evaluator(self):
42
42
43
43
# Initialize RankingEvaluator
44
44
evaluator = RankingEvaluator ().setPredictionCol ("prediction" )
45
+ self .assertTrue (evaluator .isLargerBetter ())
45
46
46
47
# Evaluate the dataset using the default metric (mean average precision)
47
48
mean_average_precision = evaluator .evaluate (dataset )
@@ -94,6 +95,25 @@ def test_multilabel_classification_evaluator(self):
94
95
self .assertEqual (evaluator2 .getPredictionCol (), "prediction" )
95
96
self .assertEqual (str (evaluator ), str (evaluator2 ))
96
97
98
+ for metric in [
99
+ "subsetAccuracy" ,
100
+ "accuracy" ,
101
+ "precision" ,
102
+ "recall" ,
103
+ "f1Measure" ,
104
+ "precisionByLabel" ,
105
+ "recallByLabel" ,
106
+ "f1MeasureByLabel" ,
107
+ "microPrecision" ,
108
+ "microRecall" ,
109
+ "microF1Measure" ,
110
+ ]:
111
+ evaluator .setMetricName (metric )
112
+ self .assertTrue (evaluator .isLargerBetter ())
113
+
114
+ evaluator .setMetricName ("hammingLoss" )
115
+ self .assertTrue (not evaluator .isLargerBetter ())
116
+
97
117
def test_multiclass_classification_evaluator (self ):
98
118
dataset = self .spark .createDataFrame (
99
119
[
@@ -163,6 +183,29 @@ def test_multiclass_classification_evaluator(self):
163
183
log_loss = evaluator .evaluate (dataset )
164
184
self .assertTrue (np .allclose (log_loss , 1.0093 , atol = 1e-4 ))
165
185
186
+ for metric in [
187
+ "f1" ,
188
+ "accuracy" ,
189
+ "weightedPrecision" ,
190
+ "weightedRecall" ,
191
+ "weightedTruePositiveRate" ,
192
+ "weightedFMeasure" ,
193
+ "truePositiveRateByLabel" ,
194
+ "precisionByLabel" ,
195
+ "recallByLabel" ,
196
+ "fMeasureByLabel" ,
197
+ ]:
198
+ evaluator .setMetricName (metric )
199
+ self .assertTrue (evaluator .isLargerBetter ())
200
+ for metric in [
201
+ "weightedFalsePositiveRate" ,
202
+ "falsePositiveRateByLabel" ,
203
+ "logLoss" ,
204
+ "hammingLoss" ,
205
+ ]:
206
+ evaluator .setMetricName (metric )
207
+ self .assertTrue (not evaluator .isLargerBetter ())
208
+
166
209
def test_binary_classification_evaluator (self ):
167
210
# Define score and labels data
168
211
data = map (
@@ -180,6 +223,8 @@ def test_binary_classification_evaluator(self):
180
223
dataset = self .spark .createDataFrame (data , ["raw" , "label" , "weight" ])
181
224
182
225
evaluator = BinaryClassificationEvaluator ().setRawPredictionCol ("raw" )
226
+ self .assertTrue (evaluator .isLargerBetter ())
227
+
183
228
auc_roc = evaluator .evaluate (dataset )
184
229
self .assertTrue (np .allclose (auc_roc , 0.7083 , atol = 1e-4 ))
185
230
@@ -226,6 +271,8 @@ def test_clustering_evaluator(self):
226
271
dataset = self .spark .createDataFrame (data , ["features" , "prediction" , "weight" ])
227
272
228
273
evaluator = ClusteringEvaluator ().setPredictionCol ("prediction" )
274
+ self .assertTrue (evaluator .isLargerBetter ())
275
+
229
276
score = evaluator .evaluate (dataset )
230
277
self .assertTrue (np .allclose (score , 0.9079 , atol = 1e-4 ))
231
278
@@ -300,6 +347,13 @@ def test_regression_evaluator(self):
300
347
through_origin = evaluator_with_weights .getThroughOrigin ()
301
348
self .assertEqual (through_origin , False )
302
349
350
+ for metric in ["mse" , "rmse" , "mae" ]:
351
+ evaluator .setMetricName (metric )
352
+ self .assertTrue (not evaluator .isLargerBetter ())
353
+ for metric in ["r2" , "var" ]:
354
+ evaluator .setMetricName (metric )
355
+ self .assertTrue (evaluator .isLargerBetter ())
356
+
303
357
304
358
class EvaluatorTests (EvaluatorTestsMixin , unittest .TestCase ):
305
359
def setUp (self ) -> None :
0 commit comments