Skip to content

Commit 47b42a9

Browse files
authored
Merge pull request #426 from s22s/fix/425-ml-pipeline
Fix #425 ML Pipeline save/load
2 parents b2aa335 + 6645246 commit 47b42a9

File tree

3 files changed

+68
-4
lines changed

3 files changed

+68
-4
lines changed

pyrasterframes/src/main/python/pyrasterframes/rf_types.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class here provides the PyRasterFrames entry point.
3131

3232
from pyspark.ml.param.shared import HasInputCols
3333
from pyspark.ml.wrapper import JavaTransformer
34-
from pyspark.ml.util import JavaMLReadable, JavaMLWritable
34+
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
3535

3636
from pyrasterframes.rf_context import RFContext
3737

@@ -462,7 +462,7 @@ def deserialize(self, datum):
462462
Tile.__UDT__ = TileUDT()
463463

464464

465-
class TileExploder(JavaTransformer, JavaMLReadable, JavaMLWritable):
465+
class TileExploder(JavaTransformer, DefaultParamsReadable, DefaultParamsWritable):
466466
"""
467467
Python wrapper for TileExploder.scala
468468
"""
@@ -472,7 +472,7 @@ def __init__(self):
472472
self._java_obj = self._new_java_obj("org.locationtech.rasterframes.ml.TileExploder", self.uid)
473473

474474

475-
class NoDataFilter(JavaTransformer, HasInputCols, JavaMLReadable, JavaMLWritable):
475+
class NoDataFilter(JavaTransformer, HasInputCols, DefaultParamsReadable, DefaultParamsWritable):
476476
"""
477477
Python wrapper for NoDataFilter.scala
478478
"""

pyrasterframes/src/main/python/tests/ExploderTests.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from pyrasterframes import TileExploder
2626

2727
from pyspark.ml.feature import VectorAssembler
28-
from pyspark.ml import Pipeline
28+
from pyspark.ml import Pipeline, PipelineModel
2929
from pyspark.sql.functions import *
3030

3131
import unittest
@@ -56,3 +56,16 @@ def test_tile_exploder_pipeline_for_tile(self):
5656
pipe_model = pipe.fit(df)
5757
tranformed_df = pipe_model.transform(df)
5858
self.assertTrue(tranformed_df.count() > df.count())
59+
60+
def test_tile_exploder_read_write(self):
61+
path = 'test_tile_exploder_read_write.pipe'
62+
df = self.spark.read.raster(self.img_uri)
63+
64+
assembler = VectorAssembler().setInputCols(['proj_raster'])
65+
pipe = Pipeline().setStages([TileExploder(), assembler])
66+
67+
pipe.fit(df).write().overwrite().save(path)
68+
69+
read_pipe = PipelineModel.load(path)
70+
self.assertEqual(len(read_pipe.stages), 2)
71+
self.assertTrue(isinstance(read_pipe.stages[0], TileExploder))
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#
2+
# This software is licensed under the Apache 2 license, quoted below.
3+
#
4+
# Copyright 2019 Astraea, Inc.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
7+
# use this file except in compliance with the License. You may obtain a copy of
8+
# the License at
9+
#
10+
# [http://www.apache.org/licenses/LICENSE-2.0]
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15+
# License for the specific language governing permissions and limitations under
16+
# the License.
17+
#
18+
# SPDX-License-Identifier: Apache-2.0
19+
#
20+
21+
from . import TestEnvironment
22+
23+
from pyrasterframes.rasterfunctions import *
24+
from pyrasterframes.rf_types import *
25+
26+
from pyspark.ml.feature import VectorAssembler
27+
from pyspark.ml import Pipeline, PipelineModel
28+
from pyspark.sql.functions import *
29+
30+
import unittest
31+
32+
33+
class ExploderTests(TestEnvironment):
34+
35+
def test_no_data_filter_read_write(self):
36+
path = 'test_no_data_filter_read_write.pipe'
37+
df = self.spark.read.raster(self.img_uri) \
38+
.select(rf_tile_mean('proj_raster').alias('mean'))
39+
40+
input_cols = ['mean']
41+
ndf = NoDataFilter().setInputCols(input_cols)
42+
assembler = VectorAssembler().setInputCols(input_cols)
43+
44+
pipe = Pipeline().setStages([ndf, assembler])
45+
46+
pipe.fit(df).write().overwrite().save(path)
47+
48+
read_pipe = PipelineModel.load(path)
49+
self.assertEqual(len(read_pipe.stages), 2)
50+
actual_stages_ndf = read_pipe.stages[0].getInputCols()
51+
self.assertEqual(actual_stages_ndf, input_cols)

0 commit comments

Comments (0)