|
16 | 16 | #
|
17 | 17 |
|
18 | 18 | import sys
|
| 19 | +import os |
19 | 20 |
|
20 | 21 | if sys.version > '3':
|
21 | 22 | basestring = str
|
22 | 23 |
|
23 | 24 | from pyspark import since, keyword_only, SparkContext
|
24 | 25 | from pyspark.ml.base import Estimator, Model, Transformer
|
25 | 26 | from pyspark.ml.param import Param, Params
|
26 |
| -from pyspark.ml.util import JavaMLWriter, JavaMLReader, MLReadable, MLWritable |
| 27 | +from pyspark.ml.util import * |
27 | 28 | from pyspark.ml.wrapper import JavaParams
|
28 | 29 | from pyspark.ml.common import inherit_doc
|
29 | 30 |
|
@@ -130,13 +131,16 @@ def copy(self, extra=None):
|
130 | 131 | @since("2.0.0")
|
131 | 132 | def write(self):
|
132 | 133 | """Returns an MLWriter instance for this ML instance."""
|
133 |
| - return JavaMLWriter(self) |
| 134 | + allStagesAreJava = PipelineSharedReadWrite.checkStagesForJava(self.getStages()) |
| 135 | + if allStagesAreJava: |
| 136 | + return JavaMLWriter(self) |
| 137 | + return PipelineWriter(self) |
134 | 138 |
|
135 | 139 | @classmethod
|
136 | 140 | @since("2.0.0")
|
137 | 141 | def read(cls):
|
138 | 142 | """Returns an MLReader instance for this class."""
|
139 |
| - return JavaMLReader(cls) |
| 143 | + return PipelineReader(cls) |
140 | 144 |
|
141 | 145 | @classmethod
|
142 | 146 | def _from_java(cls, java_stage):
|
@@ -171,6 +175,76 @@ def _to_java(self):
|
171 | 175 | return _java_obj
|
172 | 176 |
|
173 | 177 |
|
@inherit_doc
class PipelineWriter(MLWriter):
    """
    (Private) Specialization of :py:class:`MLWriter` for :py:class:`Pipeline` types
    """

    def __init__(self, instance):
        super(PipelineWriter, self).__init__()
        # The Pipeline being persisted.
        self.instance = instance

    def saveImpl(self, path):
        # Validate every stage up front, then delegate the actual layout
        # (metadata + per-stage subdirectories) to the shared helper.
        pipeline_stages = self.instance.getStages()
        PipelineSharedReadWrite.validateStages(pipeline_stages)
        PipelineSharedReadWrite.saveImpl(self.instance, pipeline_stages, self.sc, path)
| 192 | + |
| 193 | + |
@inherit_doc
class PipelineReader(MLReader):
    """
    (Private) Specialization of :py:class:`MLReader` for :py:class:`Pipeline` types
    """

    def __init__(self, cls):
        super(PipelineReader, self).__init__()
        # Class object used when falling back to the Java reader.
        self.cls = cls

    def load(self, path):
        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
        params = metadata['paramMap']
        # Pipelines written from Python tag their metadata with
        # language='Python'; anything else was written by the JVM.
        if params.get('language') == 'Python':
            uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
            return Pipeline(stages=stages)._resetUid(uid)
        return JavaMLReader(self.cls).load(path)
| 211 | + |
| 212 | + |
@inherit_doc
class PipelineModelWriter(MLWriter):
    """
    (Private) Specialization of :py:class:`MLWriter` for :py:class:`PipelineModel` types
    """

    def __init__(self, instance):
        super(PipelineModelWriter, self).__init__()
        # The fitted PipelineModel being persisted.
        self.instance = instance

    def saveImpl(self, path):
        # Unlike PipelineWriter, a fitted model exposes its stages as an
        # attribute rather than via getStages().
        fitted_stages = self.instance.stages
        PipelineSharedReadWrite.validateStages(fitted_stages)
        PipelineSharedReadWrite.saveImpl(self.instance, fitted_stages, self.sc, path)
| 227 | + |
| 228 | + |
@inherit_doc
class PipelineModelReader(MLReader):
    """
    (Private) Specialization of :py:class:`MLReader` for :py:class:`PipelineModel` types
    """

    def __init__(self, cls):
        super(PipelineModelReader, self).__init__()
        # Class object used when falling back to the Java reader.
        self.cls = cls

    def load(self, path):
        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
        params = metadata['paramMap']
        # Only models saved from Python carry language='Python' in their
        # metadata; everything else goes through the Java reader.
        if params.get('language') == 'Python':
            uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
            return PipelineModel(stages=stages)._resetUid(uid)
        return JavaMLReader(self.cls).load(path)
| 246 | + |
| 247 | + |
174 | 248 | @inherit_doc
|
175 | 249 | class PipelineModel(Model, MLReadable, MLWritable):
|
176 | 250 | """
|
@@ -204,13 +278,16 @@ def copy(self, extra=None):
|
204 | 278 | @since("2.0.0")
|
205 | 279 | def write(self):
|
206 | 280 | """Returns an MLWriter instance for this ML instance."""
|
207 |
| - return JavaMLWriter(self) |
| 281 | + allStagesAreJava = PipelineSharedReadWrite.checkStagesForJava(self.stages) |
| 282 | + if allStagesAreJava: |
| 283 | + return JavaMLWriter(self) |
| 284 | + return PipelineModelWriter(self) |
208 | 285 |
|
209 | 286 | @classmethod
|
210 | 287 | @since("2.0.0")
|
211 | 288 | def read(cls):
|
212 | 289 | """Returns an MLReader instance for this class."""
|
213 |
| - return JavaMLReader(cls) |
| 290 | + return PipelineModelReader(cls) |
214 | 291 |
|
215 | 292 | @classmethod
|
216 | 293 | def _from_java(cls, java_stage):
|
@@ -242,3 +319,72 @@ def _to_java(self):
|
242 | 319 | JavaParams._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)
|
243 | 320 |
|
244 | 321 | return _java_obj
|
| 322 | + |
| 323 | + |
@inherit_doc
class PipelineSharedReadWrite():
    """
    .. note:: DeveloperApi

    Functions for :py:class:`MLReader` and :py:class:`MLWriter` shared between
    :py:class:`Pipeline` and :py:class:`PipelineModel`

    .. versionadded:: 2.3.0
    """

    @staticmethod
    def checkStagesForJava(stages):
        """Returns True iff every stage is Java-backed (JavaMLWritable)."""
        return all(isinstance(stage, JavaMLWritable) for stage in stages)

    @staticmethod
    def validateStages(stages):
        """
        Check that all stages are Writable

        :raises ValueError: if any stage is not an :py:class:`MLWritable`.
        """
        for stage in stages:
            if not isinstance(stage, MLWritable):
                # Fail fast before any stage has been written to disk.
                # BUG FIX: the uid/type were previously passed as extra
                # exception args instead of being interpolated into the
                # message, so the message never identified the bad stage.
                raise ValueError("Pipeline write will fail on this pipeline "
                                 "because stage %s of type %s is not MLWritable"
                                 % (stage.uid, type(stage)))

    @staticmethod
    def saveImpl(instance, stages, sc, path):
        """
        Save metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`
        - save metadata to path/metadata
        - save stages to stages/IDX_UID

        :param instance: the Pipeline or PipelineModel being saved
        :param stages: list of stages to persist (each must be MLWritable)
        :param sc: active SparkContext
        :param path: root directory for the persisted pipeline
        """
        stageUids = [stage.uid for stage in stages]
        # 'language': 'Python' marks this as a Python-written pipeline so
        # the readers can dispatch away from the Java format on load.
        jsonParams = {'stageUids': stageUids, 'language': 'Python'}
        DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap=jsonParams)
        stagesDir = os.path.join(path, "stages")
        for index, stage in enumerate(stages):
            stage.write().save(PipelineSharedReadWrite
                               .getStagePath(stage.uid, index, len(stages), stagesDir))

    @staticmethod
    def load(metadata, sc, path):
        """
        Load metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`

        :param metadata: metadata dict previously produced by saveImpl
        :param sc: active SparkContext
        :param path: root directory the pipeline was saved under
        :return: (UID, list of stages)
        """
        stagesDir = os.path.join(path, "stages")
        stageUids = metadata['paramMap']['stageUids']
        stages = []
        for index, stageUid in enumerate(stageUids):
            # Reconstruct the exact per-stage directory name used by saveImpl.
            stagePath = \
                PipelineSharedReadWrite.getStagePath(stageUid, index, len(stageUids), stagesDir)
            stage = DefaultParamsReader.loadParamsInstance(stagePath, sc)
            stages.append(stage)
        return (metadata['uid'], stages)

    @staticmethod
    def getStagePath(stageUid, stageIdx, numStages, stagesDir):
        """
        Get path for saving the given stage.

        The index is zero-padded to the width of numStages ("IDX_UID") so
        stage directories sort lexicographically in pipeline order.
        """
        stageIdxDigits = len(str(numStages))
        stageDir = str(stageIdx).zfill(stageIdxDigits) + "_" + stageUid
        stagePath = os.path.join(stagesDir, stageDir)
        return stagePath
0 commit comments