Skip to content

Commit e4c2903

Browse files
committed
Python unit tests check read pipeline stages
Signed-off-by: Jason T. Brown <[email protected]>
1 parent 13d04e6 commit e4c2903

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

pyrasterframes/src/main/python/tests/ExploderTests.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,4 @@ def test_tile_exploder_read_write(self):
6868

6969
read_pipe = PipelineModel.load(path)
7070
self.assertEqual(len(read_pipe.stages), 2)
71+
self.assertTrue(isinstance(read_pipe.stages[0], TileExploder))

pyrasterframes/src/main/python/tests/NoDataFilterTests.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,15 @@ def test_no_data_filter_read_write(self):
3737
df = self.spark.read.raster(self.img_uri) \
3838
.select(rf_tile_mean('proj_raster').alias('mean'))
3939

40-
ndf = NoDataFilter().setInputCols(['mean'])
41-
assembler = VectorAssembler().setInputCols(['mean'])
40+
input_cols = ['mean']
41+
ndf = NoDataFilter().setInputCols(input_cols)
42+
assembler = VectorAssembler().setInputCols(input_cols)
4243

4344
pipe = Pipeline().setStages([ndf, assembler])
4445

4546
pipe.fit(df).write().overwrite().save(path)
4647

4748
read_pipe = PipelineModel.load(path)
4849
self.assertEqual(len(read_pipe.stages), 2)
50+
actual_stages_ndf = read_pipe.stages[0].getInputCols()
51+
self.assertEqual(actual_stages_ndf, input_cols)

0 commit comments

Comments
 (0)