Skip to content

Commit 311a4e0

Browse files
wbo4958 authored and zhengruifeng committed
[SPARK-50884][ML][PYTHON][CONNECT] Support isLargerBetter in Evaluator
### What changes were proposed in this pull request? Support isLargerBetter in Evaluator ### Why are the changes needed? for parity feature ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? The newly added tests pass ### Was this patch authored or co-authored using generative AI tooling? No Closes #49620 from wbo4958/isLargerBetter. Authored-by: Bobby Wang <[email protected]> Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent 8f66aef commit 311a4e0

File tree

2 files changed

+83
-0
lines changed

2 files changed

+83
-0
lines changed

python/pyspark/ml/evaluation.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,10 @@ def setParams(
311311
kwargs = self._input_kwargs
312312
return self._set(**kwargs)
313313

314+
def isLargerBetter(self) -> bool:
315+
"""Override this function to make it run on connect"""
316+
return True
317+
314318

315319
@inherit_doc
316320
class RegressionEvaluator(
@@ -467,6 +471,10 @@ def setParams(
467471
kwargs = self._input_kwargs
468472
return self._set(**kwargs)
469473

474+
def isLargerBetter(self) -> bool:
475+
"""Override this function to make it run on connect"""
476+
return self.getMetricName() in ["r2", "var"]
477+
470478

471479
@inherit_doc
472480
class MulticlassClassificationEvaluator(
@@ -700,6 +708,15 @@ def setParams(
700708
kwargs = self._input_kwargs
701709
return self._set(**kwargs)
702710

711+
def isLargerBetter(self) -> bool:
712+
"""Override this function to make it run on connect"""
713+
return not self.getMetricName() in [
714+
"weightedFalsePositiveRate",
715+
"falsePositiveRateByLabel",
716+
"logLoss",
717+
"hammingLoss",
718+
]
719+
703720

704721
@inherit_doc
705722
class MultilabelClassificationEvaluator(
@@ -843,6 +860,10 @@ def setParams(
843860
kwargs = self._input_kwargs
844861
return self._set(**kwargs)
845862

863+
def isLargerBetter(self) -> bool:
864+
"""Override this function to make it run on connect"""
865+
return self.getMetricName() != "hammingLoss"
866+
846867

847868
@inherit_doc
848869
class ClusteringEvaluator(
@@ -1002,6 +1023,10 @@ def setWeightCol(self, value: str) -> "ClusteringEvaluator":
10021023
"""
10031024
return self._set(weightCol=value)
10041025

1026+
def isLargerBetter(self) -> bool:
1027+
"""Override this function to make it run on connect"""
1028+
return True
1029+
10051030

10061031
@inherit_doc
10071032
class RankingEvaluator(
@@ -1138,6 +1163,10 @@ def setParams(
11381163
kwargs = self._input_kwargs
11391164
return self._set(**kwargs)
11401165

1166+
def isLargerBetter(self) -> bool:
1167+
"""Override this function to make it run on connect"""
1168+
return True
1169+
11411170

11421171
if __name__ == "__main__":
11431172
import doctest

python/pyspark/ml/tests/test_evaluation.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def test_ranking_evaluator(self):
4242

4343
# Initialize RankingEvaluator
4444
evaluator = RankingEvaluator().setPredictionCol("prediction")
45+
self.assertTrue(evaluator.isLargerBetter())
4546

4647
# Evaluate the dataset using the default metric (mean average precision)
4748
mean_average_precision = evaluator.evaluate(dataset)
@@ -94,6 +95,25 @@ def test_multilabel_classification_evaluator(self):
9495
self.assertEqual(evaluator2.getPredictionCol(), "prediction")
9596
self.assertEqual(str(evaluator), str(evaluator2))
9697

98+
for metric in [
99+
"subsetAccuracy",
100+
"accuracy",
101+
"precision",
102+
"recall",
103+
"f1Measure",
104+
"precisionByLabel",
105+
"recallByLabel",
106+
"f1MeasureByLabel",
107+
"microPrecision",
108+
"microRecall",
109+
"microF1Measure",
110+
]:
111+
evaluator.setMetricName(metric)
112+
self.assertTrue(evaluator.isLargerBetter())
113+
114+
evaluator.setMetricName("hammingLoss")
115+
self.assertTrue(not evaluator.isLargerBetter())
116+
97117
def test_multiclass_classification_evaluator(self):
98118
dataset = self.spark.createDataFrame(
99119
[
@@ -163,6 +183,29 @@ def test_multiclass_classification_evaluator(self):
163183
log_loss = evaluator.evaluate(dataset)
164184
self.assertTrue(np.allclose(log_loss, 1.0093, atol=1e-4))
165185

186+
for metric in [
187+
"f1",
188+
"accuracy",
189+
"weightedPrecision",
190+
"weightedRecall",
191+
"weightedTruePositiveRate",
192+
"weightedFMeasure",
193+
"truePositiveRateByLabel",
194+
"precisionByLabel",
195+
"recallByLabel",
196+
"fMeasureByLabel",
197+
]:
198+
evaluator.setMetricName(metric)
199+
self.assertTrue(evaluator.isLargerBetter())
200+
for metric in [
201+
"weightedFalsePositiveRate",
202+
"falsePositiveRateByLabel",
203+
"logLoss",
204+
"hammingLoss",
205+
]:
206+
evaluator.setMetricName(metric)
207+
self.assertTrue(not evaluator.isLargerBetter())
208+
166209
def test_binary_classification_evaluator(self):
167210
# Define score and labels data
168211
data = map(
@@ -180,6 +223,8 @@ def test_binary_classification_evaluator(self):
180223
dataset = self.spark.createDataFrame(data, ["raw", "label", "weight"])
181224

182225
evaluator = BinaryClassificationEvaluator().setRawPredictionCol("raw")
226+
self.assertTrue(evaluator.isLargerBetter())
227+
183228
auc_roc = evaluator.evaluate(dataset)
184229
self.assertTrue(np.allclose(auc_roc, 0.7083, atol=1e-4))
185230

@@ -226,6 +271,8 @@ def test_clustering_evaluator(self):
226271
dataset = self.spark.createDataFrame(data, ["features", "prediction", "weight"])
227272

228273
evaluator = ClusteringEvaluator().setPredictionCol("prediction")
274+
self.assertTrue(evaluator.isLargerBetter())
275+
229276
score = evaluator.evaluate(dataset)
230277
self.assertTrue(np.allclose(score, 0.9079, atol=1e-4))
231278

@@ -300,6 +347,13 @@ def test_regression_evaluator(self):
300347
through_origin = evaluator_with_weights.getThroughOrigin()
301348
self.assertEqual(through_origin, False)
302349

350+
for metric in ["mse", "rmse", "mae"]:
351+
evaluator.setMetricName(metric)
352+
self.assertTrue(not evaluator.isLargerBetter())
353+
for metric in ["r2", "var"]:
354+
evaluator.setMetricName(metric)
355+
self.assertTrue(evaluator.isLargerBetter())
356+
303357

304358
class EvaluatorTests(EvaluatorTestsMixin, unittest.TestCase):
305359
def setUp(self) -> None:

0 commit comments

Comments (0)