
Commit 762599c

wbo4958 authored and zhengruifeng committed
[SPARK-50940][ML][PYTHON][CONNECT] Adds support for CrossValidator/CrossValidatorModel on Connect
### What changes were proposed in this pull request?

Support CrossValidator/CrossValidatorModel on connect.

### Why are the changes needed?

For feature parity.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

The newly added tests pass.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #49644 from wbo4958/cv.

Authored-by: Bobby Wang <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent e0437e0 commit 762599c
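
For context, the workflow this commit enables looks roughly like the sketch below. It is illustrative only: the Connect address (sc://localhost:15002) and the save path are assumptions, and the toy dataset mirrors the one used in the new tests.

    from pyspark.sql import SparkSession
    from pyspark.ml.classification import LogisticRegression
    from pyspark.ml.evaluation import BinaryClassificationEvaluator
    from pyspark.ml.linalg import Vectors
    from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, CrossValidatorModel

    # Connect to a Spark Connect server; the address is an assumption for this sketch.
    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

    dataset = spark.createDataFrame(
        [
            (Vectors.dense([0.0]), 0.0),
            (Vectors.dense([0.4]), 1.0),
            (Vectors.dense([0.5]), 0.0),
            (Vectors.dense([0.6]), 1.0),
            (Vectors.dense([1.0]), 1.0),
        ]
        * 10,
        ["features", "label"],
    )

    lr = LogisticRegression()
    grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
    cv = CrossValidator(
        estimator=lr,
        estimatorParamMaps=grid,
        evaluator=BinaryClassificationEvaluator(),
    )

    # With this commit, fitting, saving, and loading the validator and its
    # model work over Connect instead of raising NotImplementedError.
    model = cv.fit(dataset)
    model.write().save("/tmp/cv_model")
    reloaded = CrossValidatorModel.load("/tmp/cv_model")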

File tree: 6 files changed, 224 additions & 15 deletions

dev/sparktestsupport/modules.py

Lines changed: 2 additions & 0 deletions
@@ -675,6 +675,7 @@ def __hash__(self):
         "pyspark.ml.tests.test_param",
         "pyspark.ml.tests.test_persistence",
         "pyspark.ml.tests.test_pipeline",
+        "pyspark.ml.tests.test_tuning",
         "pyspark.ml.tests.test_stat",
         "pyspark.ml.tests.test_training_summary",
         "pyspark.ml.tests.tuning.test_tuning",
@@ -1127,6 +1128,7 @@ def __hash__(self):
         "pyspark.ml.tests.connect.test_parity_evaluation",
         "pyspark.ml.tests.connect.test_parity_feature",
         "pyspark.ml.tests.connect.test_parity_pipeline",
+        "pyspark.ml.tests.connect.test_parity_tuning",
     ],
     excluded_python_implementations=[
         "PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and

python/pyspark/ml/connect/readwrite.py

Lines changed: 46 additions & 2 deletions
@@ -14,12 +14,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
 import warnings
-from typing import cast, Type, TYPE_CHECKING, Union, List, Dict, Any
+from typing import cast, Type, TYPE_CHECKING, Union, List, Dict, Any, Optional
 
 import pyspark.sql.connect.proto as pb2
 from pyspark.ml.connect.serialize import serialize_ml_params, deserialize, deserialize_param
+from pyspark.ml.tuning import CrossValidatorModelWriter, CrossValidatorModel
 from pyspark.ml.util import MLWriter, MLReader, RL
 from pyspark.ml.wrapper import JavaWrapper
 
@@ -29,6 +29,19 @@
     from pyspark.ml.util import JavaMLReadable, JavaMLWritable
 
 
+class RemoteCrossValidatorModelWriter(CrossValidatorModelWriter):
+    def __init__(
+        self,
+        instance: "CrossValidatorModel",
+        optionMap: Dict[str, Any] = {},
+        session: Optional["SparkSession"] = None,
+    ):
+        super(RemoteCrossValidatorModelWriter, self).__init__(instance)
+        self.instance = instance
+        self.optionMap = optionMap
+        self.session(session)  # type: ignore[arg-type]
+
+
 class RemoteMLWriter(MLWriter):
     def __init__(self, instance: "JavaMLWritable") -> None:
         super().__init__()
@@ -63,6 +76,7 @@ def saveInstance(
        from pyspark.ml.wrapper import JavaModel, JavaEstimator, JavaTransformer
        from pyspark.ml.evaluation import JavaEvaluator
        from pyspark.ml.pipeline import Pipeline, PipelineModel
+       from pyspark.ml.tuning import CrossValidator
 
        # Spark Connect ML is built on scala Spark.ML, that means we're only
        # supporting JavaModel or JavaEstimator or JavaEvaluator
@@ -126,6 +140,21 @@ def saveInstance(
                path,
            )
 
+       elif isinstance(instance, CrossValidator):
+           from pyspark.ml.tuning import CrossValidatorWriter
+
+           if shouldOverwrite:
+               # TODO(SPARK-50954): Support client side model path overwrite
+               warnings.warn("Overwrite doesn't take effect for CrossValidator")
+           cv_writer = CrossValidatorWriter(instance)
+           cv_writer.session(session)  # type: ignore[arg-type]
+           cv_writer.save(path)
+       elif isinstance(instance, CrossValidatorModel):
+           if shouldOverwrite:
+               # TODO(SPARK-50954): Support client side model path overwrite
+               warnings.warn("Overwrite doesn't take effect for CrossValidatorModel")
+           cvm_writer = RemoteCrossValidatorModelWriter(instance, optionMap, session)
+           cvm_writer.save(path)
        else:
            raise NotImplementedError(f"Unsupported write for {instance.__class__}")
 
@@ -153,6 +182,7 @@ def loadInstance(
        from pyspark.ml.wrapper import JavaModel, JavaEstimator, JavaTransformer
        from pyspark.ml.evaluation import JavaEvaluator
        from pyspark.ml.pipeline import Pipeline, PipelineModel
+       from pyspark.ml.tuning import CrossValidator
 
        if (
            issubclass(clazz, JavaModel)
@@ -217,5 +247,19 @@ def _get_class() -> Type[RL]:
            else:
                return PipelineModel(stages=cast(List[Transformer], stages))._resetUid(uid)
 
+       elif issubclass(clazz, CrossValidator):
+           from pyspark.ml.tuning import CrossValidatorReader
+
+           cv_reader = CrossValidatorReader(CrossValidator)
+           cv_reader.session(session)
+           return cv_reader.load(path)
+
+       elif issubclass(clazz, CrossValidatorModel):
+           from pyspark.ml.tuning import CrossValidatorModelReader
+
+           cvm_reader = CrossValidatorModelReader(CrossValidator)
+           cvm_reader.session(session)
+           return cvm_reader.load(path)
+
        else:
            raise RuntimeError(f"Unsupported read for {clazz}")
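
Two things worth noting in the dispatch above. First, save routes a CrossValidator through the stock CrossValidatorWriter (with the Connect session attached) and a CrossValidatorModel through the new RemoteCrossValidatorModelWriter. Second, overwrite mode is accepted but currently ignored for both, pending SPARK-50954. A short illustration of that behavior, assuming cv is a CrossValidator on a Connect session as in the earlier sketch:

    cv.write().save("/tmp/cv")

    # Requesting overwrite only emits a warning for now; the existing path is
    # not cleared client-side (see the TODO(SPARK-50954) markers above).
    cv.write().overwrite().save("/tmp/cv")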
python/pyspark/ml/tests/connect/test_parity_tuning.py (new file)

Lines changed: 56 additions & 0 deletions

@@ -0,0 +1,56 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.ml.classification import LogisticRegression
+from pyspark.ml.connect.readwrite import RemoteCrossValidatorModelWriter
+from pyspark.ml.linalg import Vectors
+from pyspark.ml.tests.test_tuning import TuningTestsMixin
+from pyspark.ml.tuning import CrossValidatorModel
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class TuningParityTests(TuningTestsMixin, ReusedConnectTestCase):
+    def test_remote_cross_validator_model_writer(self):
+        df = self.spark.createDataFrame(
+            [
+                (1.0, 1.0, Vectors.dense(0.0, 5.0)),
+                (0.0, 2.0, Vectors.dense(1.0, 2.0)),
+                (1.0, 3.0, Vectors.dense(2.0, 1.0)),
+                (0.0, 4.0, Vectors.dense(3.0, 3.0)),
+            ],
+            ["label", "weight", "features"],
+        )
+
+        lor = LogisticRegression()
+        lor_model = lor.fit(df)
+        cv_model = CrossValidatorModel(lor_model)
+        writer = RemoteCrossValidatorModelWriter(cv_model, {"a": "b"}, self.spark)
+        self.assertEqual(writer.optionMap["a"], "b")
+
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.connect.test_parity_tuning import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
python/pyspark/ml/tests/test_tuning.py (new file)

Lines changed: 98 additions & 0 deletions

@@ -0,0 +1,98 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import tempfile
+import unittest
+
+import numpy as np
+
+from pyspark.ml.evaluation import BinaryClassificationEvaluator
+from pyspark.ml.linalg import Vectors
+from pyspark.ml.classification import LogisticRegression
+from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, CrossValidatorModel
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class TuningTestsMixin:
+    def test_cross_validator(self):
+        dataset = self.spark.createDataFrame(
+            [
+                (Vectors.dense([0.0]), 0.0),
+                (Vectors.dense([0.4]), 1.0),
+                (Vectors.dense([0.5]), 0.0),
+                (Vectors.dense([0.6]), 1.0),
+                (Vectors.dense([1.0]), 1.0),
+            ]
+            * 10,
+            ["features", "label"],
+        )
+        lr = LogisticRegression()
+        grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+        evaluator = BinaryClassificationEvaluator()
+        cv = CrossValidator(
+            estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, parallelism=1
+        )
+
+        self.assertEqual(cv.getEstimator(), lr)
+        self.assertEqual(cv.getEvaluator(), evaluator)
+        self.assertEqual(cv.getParallelism(), 1)
+        self.assertEqual(cv.getEstimatorParamMaps(), grid)
+
+        model = cv.fit(dataset)
+        self.assertEqual(model.getEstimator(), lr)
+        self.assertEqual(model.getEvaluator(), evaluator)
+        self.assertEqual(model.getEstimatorParamMaps(), grid)
+        self.assertTrue(np.isclose(model.avgMetrics[0], 0.5, atol=1e-4))
+
+        output = model.transform(dataset)
+        self.assertEqual(
+            output.columns, ["features", "label", "rawPrediction", "probability", "prediction"]
+        )
+        self.assertEqual(output.count(), 50)
+
+        # save & load
+        with tempfile.TemporaryDirectory(prefix="cv_lr") as d:
+            path1 = os.path.join(d, "cv")
+            cv.write().save(path1)
+            cv2 = CrossValidator.load(path1)
+            self.assertEqual(str(cv), str(cv2))
+            self.assertEqual(str(cv.getEstimator()), str(cv2.getEstimator()))
+            self.assertEqual(str(cv.getEvaluator()), str(cv2.getEvaluator()))
+
+            path2 = os.path.join(d, "cv_model")
+            model.write().save(path2)
+            model2 = CrossValidatorModel.load(path2)
+            self.assertEqual(str(model), str(model2))
+            self.assertEqual(str(model.getEstimator()), str(model2.getEstimator()))
+            self.assertEqual(str(model.getEvaluator()), str(model2.getEvaluator()))
+
+
+class TuningTests(TuningTestsMixin, ReusedSQLTestCase):
+    pass
+
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.test_tuning import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
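
A structural note: TuningTestsMixin carries the test bodies, so the identical assertions run against classic PySpark (TuningTests via ReusedSQLTestCase, above) and against Spark Connect (TuningParityTests via ReusedConnectTestCase, in the parity file); each base class merely supplies a different self.spark. The same pattern in miniature, with purely illustrative names:

    import unittest


    class EchoTestsMixin:
        # Shared test body; the concrete base class decides what self.echo is.
        def test_echo(self):
            self.assertEqual(self.echo("hi"), "hi")


    class LocalBase(unittest.TestCase):
        def setUp(self):
            self.echo = lambda s: s  # stand-in for a backend-specific fixture


    class EchoTests(EchoTestsMixin, LocalBase):
        pass


    if __name__ == "__main__":
        unittest.main()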

python/pyspark/ml/tuning.py

Lines changed: 21 additions & 11 deletions
@@ -53,6 +53,8 @@
     MLWriter,
     JavaMLReader,
     JavaMLWriter,
+    try_remote_write,
+    try_remote_read,
 )
 from pyspark.ml.wrapper import JavaParams, JavaEstimator, JavaWrapper
 from pyspark.sql.functions import col, lit, rand, UserDefinedFunction
@@ -386,7 +388,7 @@ def is_java_convertible(instance: _ValidatorParams) -> bool:
     def saveImpl(
         path: str,
         instance: _ValidatorParams,
-        sc: "SparkContext",
+        sc: Union["SparkContext", "SparkSession"],
         extraMetadata: Optional[Dict[str, Any]] = None,
     ) -> None:
         numParamsNotJson = 0
@@ -430,7 +432,7 @@ def saveImpl(
 
     @staticmethod
     def load(
-        path: str, sc: "SparkContext", metadata: Dict[str, Any]
+        path: str, sc: Union["SparkContext", "SparkSession"], metadata: Dict[str, Any]
     ) -> Tuple[Dict[str, Any], Estimator, Evaluator, List["ParamMap"]]:
         evaluatorPath = os.path.join(path, "evaluator")
         evaluator: Evaluator = DefaultParamsReader.loadParamsInstance(evaluatorPath, sc)
@@ -513,12 +515,12 @@ def __init__(self, cls: Type["CrossValidator"]):
         self.cls = cls
 
     def load(self, path: str) -> "CrossValidator":
-        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+        metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
         if not DefaultParamsReader.isPythonParamsInstance(metadata):
             return JavaMLReader(self.cls).load(path)  # type: ignore[arg-type]
         else:
             metadata, estimator, evaluator, estimatorParamMaps = _ValidatorSharedReadWrite.load(
-                path, self.sc, metadata
+                path, self.sparkSession, metadata
             )
             cv = CrossValidator(
                 estimator=estimator, estimatorParamMaps=estimatorParamMaps, evaluator=evaluator
@@ -536,7 +538,7 @@ def __init__(self, instance: "CrossValidator"):
 
     def saveImpl(self, path: str) -> None:
         _ValidatorSharedReadWrite.validateParams(self.instance)
-        _ValidatorSharedReadWrite.saveImpl(path, self.instance, self.sc)
+        _ValidatorSharedReadWrite.saveImpl(path, self.instance, self.sparkSession)
 
 
 @inherit_doc
@@ -546,16 +548,18 @@ def __init__(self, cls: Type["CrossValidatorModel"]):
         self.cls = cls
 
     def load(self, path: str) -> "CrossValidatorModel":
-        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+        metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
         if not DefaultParamsReader.isPythonParamsInstance(metadata):
             return JavaMLReader(self.cls).load(path)  # type: ignore[arg-type]
         else:
             metadata, estimator, evaluator, estimatorParamMaps = _ValidatorSharedReadWrite.load(
-                path, self.sc, metadata
+                path, self.sparkSession, metadata
             )
             numFolds = metadata["paramMap"]["numFolds"]
             bestModelPath = os.path.join(path, "bestModel")
-            bestModel: Model = DefaultParamsReader.loadParamsInstance(bestModelPath, self.sc)
+            bestModel: Model = DefaultParamsReader.loadParamsInstance(
+                bestModelPath, self.sparkSession
+            )
             avgMetrics = metadata["avgMetrics"]
             if "stdMetrics" in metadata:
                 stdMetrics = metadata["stdMetrics"]
@@ -571,7 +575,7 @@ def load(self, path: str) -> "CrossValidatorModel":
                         path, "subModels", f"fold{splitIndex}", f"{paramIndex}"
                     )
                     subModels[splitIndex][paramIndex] = DefaultParamsReader.loadParamsInstance(
-                        modelPath, self.sc
+                        modelPath, self.sparkSession
                     )
             else:
                 subModels = None
@@ -608,7 +612,9 @@ def saveImpl(self, path: str) -> None:
         if instance.stdMetrics:
             extraMetadata["stdMetrics"] = instance.stdMetrics
 
-        _ValidatorSharedReadWrite.saveImpl(path, instance, self.sc, extraMetadata=extraMetadata)
+        _ValidatorSharedReadWrite.saveImpl(
+            path, instance, self.sparkSession, extraMetadata=extraMetadata
+        )
         bestModelPath = os.path.join(path, "bestModel")
         cast(MLWritable, instance.bestModel).save(bestModelPath)
         if persistSubModels:
@@ -845,7 +851,7 @@ def _fit(self, dataset: DataFrame) -> "CrossValidatorModel":
             train = datasets[i][0].cache()
 
             tasks = map(
-                inheritable_thread_target,
+                inheritable_thread_target(dataset.sparkSession),
                 _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam),
             )
             for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
@@ -939,6 +945,7 @@ def copy(self, extra: Optional["ParamMap"] = None) -> "CrossValidator":
         return newCV
 
     @since("2.3.0")
+    @try_remote_write
     def write(self) -> MLWriter:
         """Returns an MLWriter instance for this ML instance."""
         if _ValidatorSharedReadWrite.is_java_convertible(self):
@@ -947,6 +954,7 @@ def write(self) -> MLWriter:
 
     @classmethod
     @since("2.3.0")
+    @try_remote_read
     def read(cls) -> CrossValidatorReader:
         """Returns an MLReader instance for this class."""
         return CrossValidatorReader(cls)
@@ -1077,6 +1085,7 @@ def copy(self, extra: Optional["ParamMap"] = None) -> "CrossValidatorModel":
         )
 
     @since("2.3.0")
+    @try_remote_write
     def write(self) -> MLWriter:
         """Returns an MLWriter instance for this ML instance."""
         if _ValidatorSharedReadWrite.is_java_convertible(self):
@@ -1085,6 +1094,7 @@ def write(self) -> MLWriter:
 
     @classmethod
     @since("2.3.0")
+    @try_remote_read
     def read(cls) -> CrossValidatorModelReader:
         """Returns an MLReader instance for this class."""
         return CrossValidatorModelReader(cls)
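
The @try_remote_write / @try_remote_read decorators imported here live in pyspark.ml.util and are what divert write()/read() to the Connect implementations when the client is remote. Their implementation is not part of this diff; as a rough sketch of the dispatch pattern, using only names that do appear above plus pyspark.sql.utils.is_remote:

    import functools


    def try_remote_write(f):
        # Sketch only: hand back a Connect-aware writer when the client is
        # running against Spark Connect, otherwise fall through to the
        # classic code path unchanged.
        @functools.wraps(f)
        def wrapped(self, *args, **kwargs):
            from pyspark.sql.utils import is_remote

            if is_remote():
                from pyspark.ml.connect.readwrite import RemoteMLWriter

                return RemoteMLWriter(self)
            return f(self, *args, **kwargs)

        return wrapped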
