Skip to content

Commit 6bf41f2

Browse files
authored
Merge pull request #251 from s22s/feature/unit-test-163
Breaking unit test in python for issue 163 tile exploder with Project…
2 parents da0d0f4 + 9357c56 commit 6bf41f2

File tree

1 file changed

+59
-0
lines changed

1 file changed

+59
-0
lines changed
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#
2+
# This software is licensed under the Apache 2 license, quoted below.
3+
#
4+
# Copyright 2019 Astraea, Inc.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
7+
# use this file except in compliance with the License. You may obtain a copy of
8+
# the License at
9+
#
10+
# [http://www.apache.org/licenses/LICENSE-2.0]
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15+
# License for the specific language governing permissions and limitations under
16+
# the License.
17+
#
18+
# SPDX-License-Identifier: Apache-2.0
19+
#
20+
21+
from . import TestEnvironment
22+
23+
from pyrasterframes.rasterfunctions import *
24+
from pyrasterframes.rf_types import *
25+
from pyrasterframes import TileExploder
26+
27+
from pyspark.ml.feature import VectorAssembler
28+
from pyspark.ml import Pipeline
29+
from pyspark.sql.functions import *
30+
31+
import unittest
32+
33+
34+
class ExploderTests(TestEnvironment):
35+
36+
@unittest.skip("See issue https://github.com/locationtech/rasterframes/issues/163")
37+
def test_tile_exploder_pipeline_for_prt(self):
38+
# NB the tile is a Projected Raster Tile
39+
df = self.spark.read.raster(self.img_uri)
40+
t_col = 'proj_raster'
41+
self.assertTrue(t_col in df.columns)
42+
43+
assembler = VectorAssembler().setInputCols([t_col])
44+
pipe = Pipeline().setStages([TileExploder(), assembler])
45+
pipe_model = pipe.fit(df)
46+
tranformed_df = pipe_model.transform(df)
47+
self.assertTrue(tranformed_df.count() > df.count())
48+
49+
def test_tile_exploder_pipeline_for_tile(self):
50+
t_col = 'tile'
51+
df = self.spark.read.raster(self.img_uri) \
52+
.withColumn(t_col, rf_tile('proj_raster')) \
53+
.drop('proj_raster')
54+
55+
assembler = VectorAssembler().setInputCols([t_col])
56+
pipe = Pipeline().setStages([TileExploder(), assembler])
57+
pipe_model = pipe.fit(df)
58+
tranformed_df = pipe_model.transform(df)
59+
self.assertTrue(tranformed_df.count() > df.count())

0 commit comments

Comments
 (0)