
Commit 35db3b9

ajaysaini725 authored and jkbradley committed
[SPARK-17025][ML][PYTHON] Persistence for Pipelines with Python-only Stages
## What changes were proposed in this pull request?

Implemented a Python-only persistence framework for pipelines containing stages that cannot be saved using Java.

## How was this patch tested?

Created a custom Python-only UnaryTransformer, included it in a Pipeline, and saved/loaded the pipeline. The loaded pipeline was compared against the original using _compare_pipelines() in tests.py.

Author: Ajay Saini <[email protected]>

Closes apache#18888 from ajaysaini725/PythonPipelines.
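For illustration, a minimal sketch of the round trip this patch enables. The stage object `shift_stage`, the DataFrame `df`, and the save paths are hypothetical; Pipeline, PipelineModel, Binarizer, and the writer/reader dispatch are from pyspark and this patch:

from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.feature import Binarizer

# `shift_stage` stands in for any Python-only stage, e.g. a UnaryTransformer
# subclass that also mixes in DefaultParamsReadable and DefaultParamsWritable
# (as MockUnaryTransformer does in tests.py below).
pipeline = Pipeline(stages=[shift_stage,
                            Binarizer(threshold=0.5, inputCol="x", outputCol="y")])
model = pipeline.fit(df)  # `df` is any DataFrame with the expected input column

# Because one stage is not JavaMLWritable, write() returns the Python-side
# PipelineWriter / PipelineModelWriter instead of JavaMLWriter.
pipeline.save("/tmp/pipeline")
model.save("/tmp/pipeline-model")

# The readers check the saved metadata's 'language' field and use the Python
# loader for Python-saved pipelines, falling back to JavaMLReader otherwise.
loaded_pipeline = Pipeline.load("/tmp/pipeline")
loaded_model = PipelineModel.load("/tmp/pipeline-model")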
1 parent b0bdfce commit 35db3b9

File tree

2 files changed (+183, -8 lines)

python/pyspark/ml/pipeline.py

Lines changed: 151 additions & 5 deletions
@@ -16,14 +16,15 @@
 #
 
 import sys
+import os
 
 if sys.version > '3':
     basestring = str
 
 from pyspark import since, keyword_only, SparkContext
 from pyspark.ml.base import Estimator, Model, Transformer
 from pyspark.ml.param import Param, Params
-from pyspark.ml.util import JavaMLWriter, JavaMLReader, MLReadable, MLWritable
+from pyspark.ml.util import *
 from pyspark.ml.wrapper import JavaParams
 from pyspark.ml.common import inherit_doc

@@ -130,13 +131,16 @@ def copy(self, extra=None):
     @since("2.0.0")
     def write(self):
         """Returns an MLWriter instance for this ML instance."""
-        return JavaMLWriter(self)
+        allStagesAreJava = PipelineSharedReadWrite.checkStagesForJava(self.getStages())
+        if allStagesAreJava:
+            return JavaMLWriter(self)
+        return PipelineWriter(self)
 
     @classmethod
     @since("2.0.0")
     def read(cls):
         """Returns an MLReader instance for this class."""
-        return JavaMLReader(cls)
+        return PipelineReader(cls)
 
     @classmethod
     def _from_java(cls, java_stage):
@@ -171,6 +175,76 @@ def _to_java(self):
         return _java_obj
 
 
+@inherit_doc
+class PipelineWriter(MLWriter):
+    """
+    (Private) Specialization of :py:class:`MLWriter` for :py:class:`Pipeline` types
+    """
+
+    def __init__(self, instance):
+        super(PipelineWriter, self).__init__()
+        self.instance = instance
+
+    def saveImpl(self, path):
+        stages = self.instance.getStages()
+        PipelineSharedReadWrite.validateStages(stages)
+        PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sc, path)
+
+
+@inherit_doc
+class PipelineReader(MLReader):
+    """
+    (Private) Specialization of :py:class:`MLReader` for :py:class:`Pipeline` types
+    """
+
+    def __init__(self, cls):
+        super(PipelineReader, self).__init__()
+        self.cls = cls
+
+    def load(self, path):
+        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+        if 'language' not in metadata['paramMap'] or metadata['paramMap']['language'] != 'Python':
+            return JavaMLReader(self.cls).load(path)
+        else:
+            uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
+            return Pipeline(stages=stages)._resetUid(uid)
+
+
+@inherit_doc
+class PipelineModelWriter(MLWriter):
+    """
+    (Private) Specialization of :py:class:`MLWriter` for :py:class:`PipelineModel` types
+    """
+
+    def __init__(self, instance):
+        super(PipelineModelWriter, self).__init__()
+        self.instance = instance
+
+    def saveImpl(self, path):
+        stages = self.instance.stages
+        PipelineSharedReadWrite.validateStages(stages)
+        PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sc, path)
+
+
+@inherit_doc
+class PipelineModelReader(MLReader):
+    """
+    (Private) Specialization of :py:class:`MLReader` for :py:class:`PipelineModel` types
+    """
+
+    def __init__(self, cls):
+        super(PipelineModelReader, self).__init__()
+        self.cls = cls
+
+    def load(self, path):
+        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+        if 'language' not in metadata['paramMap'] or metadata['paramMap']['language'] != 'Python':
+            return JavaMLReader(self.cls).load(path)
+        else:
+            uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
+            return PipelineModel(stages=stages)._resetUid(uid)
+
+
 @inherit_doc
 class PipelineModel(Model, MLReadable, MLWritable):
     """
@@ -204,13 +278,16 @@ def copy(self, extra=None):
     @since("2.0.0")
     def write(self):
         """Returns an MLWriter instance for this ML instance."""
-        return JavaMLWriter(self)
+        allStagesAreJava = PipelineSharedReadWrite.checkStagesForJava(self.stages)
+        if allStagesAreJava:
+            return JavaMLWriter(self)
+        return PipelineModelWriter(self)
 
     @classmethod
     @since("2.0.0")
     def read(cls):
         """Returns an MLReader instance for this class."""
-        return JavaMLReader(cls)
+        return PipelineModelReader(cls)
 
     @classmethod
     def _from_java(cls, java_stage):
@@ -242,3 +319,72 @@ def _to_java(self):
             JavaParams._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)
 
         return _java_obj
+
+
+@inherit_doc
+class PipelineSharedReadWrite():
+    """
+    .. note:: DeveloperApi
+
+    Functions for :py:class:`MLReader` and :py:class:`MLWriter` shared between
+    :py:class:`Pipeline` and :py:class:`PipelineModel`
+
+    .. versionadded:: 2.3.0
+    """
+
+    @staticmethod
+    def checkStagesForJava(stages):
+        return all(isinstance(stage, JavaMLWritable) for stage in stages)
+
+    @staticmethod
+    def validateStages(stages):
+        """
+        Check that all stages are Writable
+        """
+        for stage in stages:
+            if not isinstance(stage, MLWritable):
+                raise ValueError("Pipeline write will fail on this pipeline " +
+                                 "because stage %s of type %s is not MLWritable" %
+                                 (stage.uid, type(stage)))
+
+    @staticmethod
+    def saveImpl(instance, stages, sc, path):
+        """
+        Save metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`
+        - save metadata to path/metadata
+        - save stages to stages/IDX_UID
+        """
+        stageUids = [stage.uid for stage in stages]
+        jsonParams = {'stageUids': stageUids, 'language': 'Python'}
+        DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap=jsonParams)
+        stagesDir = os.path.join(path, "stages")
+        for index, stage in enumerate(stages):
+            stage.write().save(PipelineSharedReadWrite
+                               .getStagePath(stage.uid, index, len(stages), stagesDir))
+
+    @staticmethod
+    def load(metadata, sc, path):
+        """
+        Load metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`
+
+        :return: (UID, list of stages)
+        """
+        stagesDir = os.path.join(path, "stages")
+        stageUids = metadata['paramMap']['stageUids']
+        stages = []
+        for index, stageUid in enumerate(stageUids):
+            stagePath = \
+                PipelineSharedReadWrite.getStagePath(stageUid, index, len(stageUids), stagesDir)
+            stage = DefaultParamsReader.loadParamsInstance(stagePath, sc)
+            stages.append(stage)
+        return (metadata['uid'], stages)
+
+    @staticmethod
+    def getStagePath(stageUid, stageIdx, numStages, stagesDir):
+        """
+        Get path for saving the given stage.
+        """
+        stageIdxDigits = len(str(numStages))
+        stageDir = str(stageIdx).zfill(stageIdxDigits) + "_" + stageUid
+        stagePath = os.path.join(stagesDir, stageDir)
+        return stagePath
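As a worked example of the layout getStagePath produces (the stage uid and paths here are made up): with a 12-stage pipeline, stageIdxDigits is len(str(12)) == 2, so directory names are zero-padded to two digits and sort in pipeline order:

PipelineSharedReadWrite.getStagePath("tf_abc", 3, 12, "/tmp/pipeline/stages")
# -> '/tmp/pipeline/stages/03_tf_abc'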

python/pyspark/ml/tests.py

Lines changed: 32 additions & 3 deletions
@@ -123,7 +123,7 @@ def _transform(self, dataset):
         return dataset
 
 
-class MockUnaryTransformer(UnaryTransformer):
+class MockUnaryTransformer(UnaryTransformer, DefaultParamsReadable, DefaultParamsWritable):
 
     shift = Param(Params._dummy(), "shift", "The amount by which to shift " +
                   "data in a DataFrame",
@@ -150,7 +150,7 @@ def outputDataType(self):
     def validateInputType(self, inputType):
         if inputType != DoubleType():
             raise TypeError("Bad input type: {}. ".format(inputType) +
-                            "Requires Integer.")
+                            "Requires Double.")
 
 
 class MockEstimator(Estimator, HasFake):
@@ -1063,7 +1063,7 @@ def _compare_pipelines(self, m1, m2):
         """
         self.assertEqual(m1.uid, m2.uid)
         self.assertEqual(type(m1), type(m2))
-        if isinstance(m1, JavaParams):
+        if isinstance(m1, JavaParams) or isinstance(m1, Transformer):
             self.assertEqual(len(m1.params), len(m2.params))
             for p in m1.params:
                 self._compare_params(m1, m2, p)
@@ -1142,6 +1142,35 @@ def test_nested_pipeline_persistence(self):
         except OSError:
             pass
 
+    def test_python_transformer_pipeline_persistence(self):
+        """
+        Pipeline[MockUnaryTransformer, Binarizer]
+        """
+        temp_path = tempfile.mkdtemp()
+
+        try:
+            df = self.spark.range(0, 10).toDF('input')
+            tf = MockUnaryTransformer(shiftVal=2)\
+                .setInputCol("input").setOutputCol("shiftedInput")
+            tf2 = Binarizer(threshold=6, inputCol="shiftedInput", outputCol="binarized")
+            pl = Pipeline(stages=[tf, tf2])
+            model = pl.fit(df)
+
+            pipeline_path = temp_path + "/pipeline"
+            pl.save(pipeline_path)
+            loaded_pipeline = Pipeline.load(pipeline_path)
+            self._compare_pipelines(pl, loaded_pipeline)
+
+            model_path = temp_path + "/pipeline-model"
+            model.save(model_path)
+            loaded_model = PipelineModel.load(model_path)
+            self._compare_pipelines(model, loaded_model)
+        finally:
+            try:
+                rmtree(temp_path)
+            except OSError:
+                pass
+
     def test_onevsrest(self):
         temp_path = tempfile.mkdtemp()
         df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
