Skip to content

Commit 3f27cb7

Browse files
wbo4958 and zhengruifeng
committed
[SPARK-50941][ML][PYTHON][CONNECT] add supports for TrainValidationSplit
### What changes were proposed in this pull request?

This PR adds support for TrainValidationSplit and TrainValidationSplitModel on Connect.

### Why are the changes needed?

New feature parity.

### Does this PR introduce _any_ user-facing change?

Yes.

### How was this patch tested?

The CI passes.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #49688 from wbo4958/train_validation_split.

Lead-authored-by: Bobby Wang <wbo4958@gmail.com>
Co-authored-by: Ruifeng Zheng <ruifengz@foxmail.com>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
(cherry picked from commit 9d0e888)
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
1 parent 40b8dfa commit 3f27cb7

File tree

3 files changed

+125
-13
lines changed

3 files changed

+125
-13
lines changed

python/pyspark/ml/connect/readwrite.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,12 @@
1919

2020
import pyspark.sql.connect.proto as pb2
2121
from pyspark.ml.connect.serialize import serialize_ml_params, deserialize, deserialize_param
22-
from pyspark.ml.tuning import CrossValidatorModelWriter, CrossValidatorModel
22+
from pyspark.ml.tuning import (
23+
CrossValidatorModelWriter,
24+
CrossValidatorModel,
25+
TrainValidationSplitModel,
26+
TrainValidationSplitModelWriter,
27+
)
2328
from pyspark.ml.util import MLWriter, MLReader, RL
2429
from pyspark.ml.wrapper import JavaWrapper
2530

@@ -42,6 +47,19 @@ def __init__(
4247
self.session(session) # type: ignore[arg-type]
4348

4449

50+
class RemoteTrainValidationSplitModelWriter(TrainValidationSplitModelWriter):
51+
def __init__(
52+
self,
53+
instance: "TrainValidationSplitModel",
54+
optionMap: Dict[str, Any] = {},
55+
session: Optional["SparkSession"] = None,
56+
):
57+
super(RemoteTrainValidationSplitModelWriter, self).__init__(instance)
58+
self.instance = instance
59+
self.optionMap = optionMap
60+
self.session(session) # type: ignore[arg-type]
61+
62+
4563
class RemoteMLWriter(MLWriter):
4664
def __init__(self, instance: "JavaMLWritable") -> None:
4765
super().__init__()
@@ -76,7 +94,7 @@ def saveInstance(
7694
from pyspark.ml.wrapper import JavaModel, JavaEstimator, JavaTransformer
7795
from pyspark.ml.evaluation import JavaEvaluator
7896
from pyspark.ml.pipeline import Pipeline, PipelineModel
79-
from pyspark.ml.tuning import CrossValidator
97+
from pyspark.ml.tuning import CrossValidator, TrainValidationSplit
8098

8199
# Spark Connect ML is built on scala Spark.ML, that means we're only
82100
# supporting JavaModel or JavaEstimator or JavaEvaluator
@@ -155,6 +173,20 @@ def saveInstance(
155173
warnings.warn("Overwrite doesn't take effect for CrossValidatorModel")
156174
cvm_writer = RemoteCrossValidatorModelWriter(instance, optionMap, session)
157175
cvm_writer.save(path)
176+
elif isinstance(instance, TrainValidationSplit):
177+
from pyspark.ml.tuning import TrainValidationSplitWriter
178+
179+
if shouldOverwrite:
180+
# TODO(SPARK-50954): Support client side model path overwrite
181+
warnings.warn("Overwrite doesn't take effect for TrainValidationSplit")
182+
tvs_writer = TrainValidationSplitWriter(instance)
183+
tvs_writer.save(path)
184+
elif isinstance(instance, TrainValidationSplitModel):
185+
if shouldOverwrite:
186+
# TODO(SPARK-50954): Support client side model path overwrite
187+
warnings.warn("Overwrite doesn't take effect for TrainValidationSplitModel")
188+
tvsm_writer = RemoteTrainValidationSplitModelWriter(instance, optionMap, session)
189+
tvsm_writer.save(path)
158190
else:
159191
raise NotImplementedError(f"Unsupported write for {instance.__class__}")
160192

@@ -182,7 +214,7 @@ def loadInstance(
182214
from pyspark.ml.wrapper import JavaModel, JavaEstimator, JavaTransformer
183215
from pyspark.ml.evaluation import JavaEvaluator
184216
from pyspark.ml.pipeline import Pipeline, PipelineModel
185-
from pyspark.ml.tuning import CrossValidator
217+
from pyspark.ml.tuning import CrossValidator, TrainValidationSplit
186218

187219
if (
188220
issubclass(clazz, JavaModel)
@@ -261,5 +293,19 @@ def _get_class() -> Type[RL]:
261293
cvm_reader.session(session)
262294
return cvm_reader.load(path)
263295

296+
elif issubclass(clazz, TrainValidationSplit):
297+
from pyspark.ml.tuning import TrainValidationSplitReader
298+
299+
tvs_reader = TrainValidationSplitReader(TrainValidationSplit)
300+
tvs_reader.session(session)
301+
return tvs_reader.load(path)
302+
303+
elif issubclass(clazz, TrainValidationSplitModel):
304+
from pyspark.ml.tuning import TrainValidationSplitModelReader
305+
306+
tvs_reader = TrainValidationSplitModelReader(TrainValidationSplitModel)
307+
tvs_reader.session(session)
308+
return tvs_reader.load(path)
309+
264310
else:
265311
raise RuntimeError(f"Unsupported read for {clazz}")

python/pyspark/ml/tests/test_tuning.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,69 @@
2424
from pyspark.ml.evaluation import BinaryClassificationEvaluator
2525
from pyspark.ml.linalg import Vectors
2626
from pyspark.ml.classification import LogisticRegression
27-
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, CrossValidatorModel
27+
from pyspark.ml.tuning import (
28+
ParamGridBuilder,
29+
CrossValidator,
30+
CrossValidatorModel,
31+
TrainValidationSplit,
32+
TrainValidationSplitModel,
33+
)
2834
from pyspark.testing.sqlutils import ReusedSQLTestCase
2935

3036

3137
class TuningTestsMixin:
38+
def test_train_validation_split(self):
39+
dataset = self.spark.createDataFrame(
40+
[
41+
(Vectors.dense([0.0]), 0.0),
42+
(Vectors.dense([0.4]), 1.0),
43+
(Vectors.dense([0.5]), 0.0),
44+
(Vectors.dense([0.6]), 1.0),
45+
(Vectors.dense([1.0]), 1.0),
46+
]
47+
* 10, # Repeat the data 10 times
48+
["features", "label"],
49+
).repartition(1)
50+
51+
lr = LogisticRegression()
52+
grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
53+
evaluator = BinaryClassificationEvaluator()
54+
55+
tvs = TrainValidationSplit(
56+
estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, parallelism=1, seed=42
57+
)
58+
self.assertEqual(tvs.getEstimator(), lr)
59+
self.assertEqual(tvs.getEvaluator(), evaluator)
60+
self.assertEqual(tvs.getParallelism(), 1)
61+
self.assertEqual(tvs.getEstimatorParamMaps(), grid)
62+
63+
tvs_model = tvs.fit(dataset)
64+
65+
# Access the train ratio
66+
self.assertEqual(tvs_model.getTrainRatio(), 0.75)
67+
print("----------- ", tvs_model.validationMetrics)
68+
self.assertTrue(np.isclose(tvs_model.validationMetrics[0], 0.5, atol=1e-4))
69+
self.assertTrue(np.isclose(tvs_model.validationMetrics[1], 0.8857142857142857, atol=1e-4))
70+
71+
evaluation_score = evaluator.evaluate(tvs_model.transform(dataset))
72+
self.assertTrue(np.isclose(evaluation_score, 0.8333333333333333, atol=1e-4))
73+
74+
# save & load
75+
with tempfile.TemporaryDirectory(prefix="train_validation_split") as d:
76+
path1 = os.path.join(d, "cv")
77+
tvs.write().save(path1)
78+
tvs2 = TrainValidationSplit.load(path1)
79+
self.assertEqual(str(tvs), str(tvs2))
80+
self.assertEqual(str(tvs.getEstimator()), str(tvs2.getEstimator()))
81+
self.assertEqual(str(tvs.getEvaluator()), str(tvs2.getEvaluator()))
82+
83+
path2 = os.path.join(d, "cv_model")
84+
tvs_model.write().save(path2)
85+
model2 = TrainValidationSplitModel.load(path2)
86+
self.assertEqual(str(tvs_model), str(model2))
87+
self.assertEqual(str(tvs_model.getEstimator()), str(model2.getEstimator()))
88+
self.assertEqual(str(tvs_model.getEvaluator()), str(model2.getEvaluator()))
89+
3290
def test_cross_validator(self):
3391
dataset = self.spark.createDataFrame(
3492
[

python/pyspark/ml/tuning.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1186,12 +1186,12 @@ def __init__(self, cls: Type["TrainValidationSplit"]):
11861186
self.cls = cls
11871187

11881188
def load(self, path: str) -> "TrainValidationSplit":
1189-
metadata = DefaultParamsReader.loadMetadata(path, self.sc)
1189+
metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
11901190
if not DefaultParamsReader.isPythonParamsInstance(metadata):
11911191
return JavaMLReader(self.cls).load(path) # type: ignore[arg-type]
11921192
else:
11931193
metadata, estimator, evaluator, estimatorParamMaps = _ValidatorSharedReadWrite.load(
1194-
path, self.sc, metadata
1194+
path, self.sparkSession, metadata
11951195
)
11961196
tvs = TrainValidationSplit(
11971197
estimator=estimator, estimatorParamMaps=estimatorParamMaps, evaluator=evaluator
@@ -1209,7 +1209,7 @@ def __init__(self, instance: "TrainValidationSplit"):
12091209

12101210
def saveImpl(self, path: str) -> None:
12111211
_ValidatorSharedReadWrite.validateParams(self.instance)
1212-
_ValidatorSharedReadWrite.saveImpl(path, self.instance, self.sc)
1212+
_ValidatorSharedReadWrite.saveImpl(path, self.instance, self.sparkSession)
12131213

12141214

12151215
@inherit_doc
@@ -1219,15 +1219,17 @@ def __init__(self, cls: Type["TrainValidationSplitModel"]):
12191219
self.cls = cls
12201220

12211221
def load(self, path: str) -> "TrainValidationSplitModel":
1222-
metadata = DefaultParamsReader.loadMetadata(path, self.sc)
1222+
metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
12231223
if not DefaultParamsReader.isPythonParamsInstance(metadata):
12241224
return JavaMLReader(self.cls).load(path) # type: ignore[arg-type]
12251225
else:
12261226
metadata, estimator, evaluator, estimatorParamMaps = _ValidatorSharedReadWrite.load(
1227-
path, self.sc, metadata
1227+
path, self.sparkSession, metadata
12281228
)
12291229
bestModelPath = os.path.join(path, "bestModel")
1230-
bestModel: Model = DefaultParamsReader.loadParamsInstance(bestModelPath, self.sc)
1230+
bestModel: Model = DefaultParamsReader.loadParamsInstance(
1231+
bestModelPath, self.sparkSession
1232+
)
12311233
validationMetrics = metadata["validationMetrics"]
12321234
persistSubModels = ("persistSubModels" in metadata) and metadata["persistSubModels"]
12331235

@@ -1236,7 +1238,7 @@ def load(self, path: str) -> "TrainValidationSplitModel":
12361238
for paramIndex in range(len(estimatorParamMaps)):
12371239
modelPath = os.path.join(path, "subModels", f"{paramIndex}")
12381240
subModels[paramIndex] = DefaultParamsReader.loadParamsInstance(
1239-
modelPath, self.sc
1241+
modelPath, self.sparkSession
12401242
)
12411243
else:
12421244
subModels = None
@@ -1273,7 +1275,9 @@ def saveImpl(self, path: str) -> None:
12731275
"validationMetrics": instance.validationMetrics,
12741276
"persistSubModels": persistSubModels,
12751277
}
1276-
_ValidatorSharedReadWrite.saveImpl(path, instance, self.sc, extraMetadata=extraMetadata)
1278+
_ValidatorSharedReadWrite.saveImpl(
1279+
path, instance, self.sparkSession, extraMetadata=extraMetadata
1280+
)
12771281
bestModelPath = os.path.join(path, "bestModel")
12781282
cast(MLWritable, instance.bestModel).save(bestModelPath)
12791283
if persistSubModels:
@@ -1473,7 +1477,7 @@ def _fit(self, dataset: DataFrame) -> "TrainValidationSplitModel":
14731477
subModels = [None for i in range(numModels)]
14741478

14751479
tasks = map(
1476-
inheritable_thread_target,
1480+
inheritable_thread_target(dataset.sparkSession),
14771481
_parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam),
14781482
)
14791483
pool = ThreadPool(processes=min(self.getParallelism(), numModels))
@@ -1529,6 +1533,7 @@ def copy(self, extra: Optional["ParamMap"] = None) -> "TrainValidationSplit":
15291533
return newTVS
15301534

15311535
@since("2.3.0")
1536+
@try_remote_write
15321537
def write(self) -> MLWriter:
15331538
"""Returns an MLWriter instance for this ML instance."""
15341539
if _ValidatorSharedReadWrite.is_java_convertible(self):
@@ -1537,6 +1542,7 @@ def write(self) -> MLWriter:
15371542

15381543
@classmethod
15391544
@since("2.3.0")
1545+
@try_remote_read
15401546
def read(cls) -> TrainValidationSplitReader:
15411547
"""Returns an MLReader instance for this class."""
15421548
return TrainValidationSplitReader(cls)
@@ -1649,6 +1655,7 @@ def copy(self, extra: Optional["ParamMap"] = None) -> "TrainValidationSplitModel
16491655
)
16501656

16511657
@since("2.3.0")
1658+
@try_remote_write
16521659
def write(self) -> MLWriter:
16531660
"""Returns an MLWriter instance for this ML instance."""
16541661
if _ValidatorSharedReadWrite.is_java_convertible(self):
@@ -1657,6 +1664,7 @@ def write(self) -> MLWriter:
16571664

16581665
@classmethod
16591666
@since("2.3.0")
1667+
@try_remote_read
16601668
def read(cls) -> TrainValidationSplitModelReader:
16611669
"""Returns an MLReader instance for this class."""
16621670
return TrainValidationSplitModelReader(cls)

0 commit comments

Comments (0)