Skip to content

Commit ef363b6

Browse files
wbo4958 and zhengruifeng
authored and committed
[SPARK-50869][ML][CONNECT][PYTHON] Support evaluators on ML Connect
### What changes were proposed in this pull request? This PR adds support for Evaluators on ML Connect: - org.apache.spark.ml.evaluation.RegressionEvaluator - org.apache.spark.ml.evaluation.BinaryClassificationEvaluator - org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator - org.apache.spark.ml.evaluation.MultilabelClassificationEvaluator - org.apache.spark.ml.evaluation.ClusteringEvaluator - org.apache.spark.ml.evaluation.RankingEvaluator ### Why are the changes needed? For parity with Spark Classic. ### Does this PR introduce _any_ user-facing change? Yes, new evaluators are supported on ML Connect. ### How was this patch tested? The newly added tests can pass. ### Was this patch authored or co-authored using generative AI tooling? No Closes #49547 from wbo4958/evaluator.ml.connect. Authored-by: Bobby Wang <wbo4958@gmail.com> Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
1 parent 205e382 commit ef363b6

File tree

17 files changed

+778
-87
lines changed

17 files changed

+778
-87
lines changed

dev/sparktestsupport/modules.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,6 +1117,7 @@ def __hash__(self):
11171117
"pyspark.ml.tests.connect.test_connect_pipeline",
11181118
"pyspark.ml.tests.connect.test_connect_tuning",
11191119
"pyspark.ml.tests.connect.test_parity_classification",
1120+
"pyspark.ml.tests.connect.test_parity_evaluation",
11201121
],
11211122
excluded_python_implementations=[
11221123
"PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
# Spark Connect ML uses ServiceLoader to find out the supported Spark Ml evaluators.
19+
# So register the supported evaluator here if you're trying to add a new one.
20+
21+
org.apache.spark.ml.evaluation.RegressionEvaluator
22+
org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
23+
org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
24+
org.apache.spark.ml.evaluation.MultilabelClassificationEvaluator
25+
org.apache.spark.ml.evaluation.ClusteringEvaluator
26+
org.apache.spark.ml.evaluation.RankingEvaluator

python/pyspark/ml/connect/readwrite.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def sc(self) -> "SparkContext":
3838

3939
def save(self, path: str) -> None:
4040
from pyspark.ml.wrapper import JavaModel, JavaEstimator
41+
from pyspark.ml.evaluation import JavaEvaluator
4142
from pyspark.sql.connect.session import SparkSession
4243

4344
session = SparkSession.getActiveSession()
@@ -69,6 +70,19 @@ def save(self, path: str) -> None:
6970
should_overwrite=self.shouldOverwrite,
7071
options=self.optionMap,
7172
)
73+
elif isinstance(self._instance, JavaEvaluator):
74+
evaluator = cast("JavaEvaluator", self._instance)
75+
params = serialize_ml_params(evaluator, session.client)
76+
assert isinstance(evaluator._java_obj, str)
77+
writer = pb2.MlCommand.Write(
78+
operator=pb2.MlOperator(
79+
name=evaluator._java_obj, uid=evaluator.uid, type=pb2.MlOperator.EVALUATOR
80+
),
81+
params=params,
82+
path=path,
83+
should_overwrite=self.shouldOverwrite,
84+
options=self.optionMap,
85+
)
7286
else:
7387
raise NotImplementedError(f"Unsupported writing for {self._instance}")
7488

@@ -85,6 +99,7 @@ def __init__(self, clazz: Type["JavaMLReadable[RL]"]) -> None:
8599
def load(self, path: str) -> RL:
86100
from pyspark.sql.connect.session import SparkSession
87101
from pyspark.ml.wrapper import JavaModel, JavaEstimator
102+
from pyspark.ml.evaluation import JavaEvaluator
88103

89104
session = SparkSession.getActiveSession()
90105
assert session is not None
@@ -99,6 +114,8 @@ def load(self, path: str) -> RL:
99114
ml_type = pb2.MlOperator.MODEL
100115
elif issubclass(self._clazz, JavaEstimator):
101116
ml_type = pb2.MlOperator.ESTIMATOR
117+
elif issubclass(self._clazz, JavaEvaluator):
118+
ml_type = pb2.MlOperator.EVALUATOR
102119
else:
103120
raise ValueError(f"Unsupported reading for {java_qualified_class_name}")
104121

python/pyspark/ml/evaluation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
HasWeightCol,
3232
)
3333
from pyspark.ml.common import inherit_doc
34-
from pyspark.ml.util import JavaMLReadable, JavaMLWritable
34+
from pyspark.ml.util import JavaMLReadable, JavaMLWritable, try_remote_evaluate
3535
from pyspark.sql.dataframe import DataFrame
3636

3737
if TYPE_CHECKING:
@@ -128,6 +128,7 @@ class JavaEvaluator(JavaParams, Evaluator, metaclass=ABCMeta):
128128
implementations.
129129
"""
130130

131+
@try_remote_evaluate
131132
def _evaluate(self, dataset: DataFrame) -> float:
132133
"""
133134
Evaluates the output.
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import os
19+
import unittest
20+
21+
from pyspark.ml.tests.test_evaluation import EvaluatorTestsMixin
22+
from pyspark.sql import SparkSession
23+
24+
25+
class EvaluatorParityTests(EvaluatorTestsMixin, unittest.TestCase):
26+
def setUp(self) -> None:
27+
self.spark = SparkSession.builder.remote(
28+
os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
29+
).getOrCreate()
30+
31+
def test_assert_remote_mode(self):
32+
from pyspark.sql import is_remote
33+
34+
self.assertTrue(is_remote())
35+
36+
def tearDown(self) -> None:
37+
self.spark.stop()
38+
39+
40+
if __name__ == "__main__":
41+
from pyspark.ml.tests.connect.test_parity_evaluation import * # noqa: F401
42+
43+
try:
44+
import xmlrunner # type: ignore[import]
45+
46+
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
47+
except ImportError:
48+
testRunner = None
49+
unittest.main(testRunner=testRunner, verbosity=2)

0 commit comments

Comments
 (0)