Skip to content

Commit 2c6df24

Browse files
committed
Separated out RasterFunctionsTests.
1 parent dd496ab commit 2c6df24

File tree

3 files changed

+204
-178
lines changed

3 files changed

+204
-178
lines changed

datasource/src/main/scala/org/locationtech/rasterframes/datasource/geotiff/GeoTiffDataSource.scala

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,7 @@ import _root_.geotrellis.raster._
2828
import _root_.geotrellis.raster.io.geotiff.compression._
2929
import _root_.geotrellis.raster.io.geotiff.tags.codes.ColorSpace
3030
import _root_.geotrellis.raster.io.geotiff.{GeoTiffOptions, MultibandGeoTiff, Tags, Tiled}
31-
<<<<<<< HEAD
32-
=======
3331
import _root_.geotrellis.spark._
34-
>>>>>>> develop
3532
import com.typesafe.scalalogging.LazyLogging
3633
import org.apache.spark.sql._
3734
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider}

pyrasterframes/src/main/python/tests/PyRasterFramesTests.py

Lines changed: 0 additions & 175 deletions
Original file line numberDiff line numberDiff line change
@@ -28,181 +28,6 @@
2828
from . import TestEnvironment
2929

3030

31-
class RasterFunctions(TestEnvironment):
    """Tests for the PyRasterFrames raster column functions.

    Each test reads ``self.rf`` and ``self.spark``; ``setUp`` calls
    ``create_layer()`` first (presumably the ``TestEnvironment`` hook that
    builds the layer -- confirm against ``TestEnvironment``).
    """

    # Data-cell count asserted for the test layer in ``test_aggregations``;
    # ``test_explode`` uses the same figure as the base of its sampling bound.
    _DATA_CELLS = 387000

    def setUp(self):
        """Prepare the layer used by every test in this class."""
        self.create_layer()

    def test_setup(self):
        """The Spark session must be configured with the Kryo serializer."""
        self.assertEqual(self.spark.sparkContext.getConf().get("spark.serializer"),
                         "org.apache.spark.serializer.KryoSerializer")

    def test_identify_columns(self):
        """Tile, spatial-key and temporal-key column discovery on the layer."""
        cols = self.rf.tile_columns()
        self.assertEqual(len(cols), 1, '`tileColumns` did not find the proper number of columns.')
        print("Tile columns: ", cols)
        col = self.rf.spatial_key_column()
        self.assertIsInstance(col, Column, '`spatialKeyColumn` was not found')
        print("Spatial key column: ", col)
        col = self.rf.temporal_key_column()
        self.assertIsNone(col, '`temporalKeyColumn` should be `None`')
        print("Temporal key column: ", col)

    def test_tile_creation(self):
        """Constant/zeros/ones tile constructors yield one row per input row."""
        base = self.spark.createDataFrame([1, 2, 3, 4], 'integer')
        tiles = base.select(
            rf_make_constant_tile(3, 3, 3, "int32"),
            rf_make_zeros_tile(3, 3, "int32"),
            rf_make_ones_tile(3, 3, "int32"),
        )
        tiles.show()
        self.assertEqual(tiles.count(), 4)

    def test_multi_column_operations(self):
        """Normalized difference of a tile column with itself averages to zero."""
        df1 = self.rf.withColumnRenamed('tile', 't1').as_layer()
        df2 = self.rf.withColumnRenamed('tile', 't2').as_layer()
        df3 = df1.spatial_join(df2).as_layer()
        df3 = df3.withColumn('norm_diff', rf_normalized_difference('t1', 't2'))

        aggs = df3.agg(
            rf_agg_mean('norm_diff'),
        )
        aggs.show()
        row = aggs.first()

        self.assertTrue(self.rounded_compare(row['rf_agg_mean(norm_diff)'], 0))

    def test_general(self):
        """Smoke test: a wide selection of tile functions evaluates without error."""
        meta = self.rf.tile_layer_metadata()
        self.assertIsNotNone(meta['bounds'])
        df = (
            self.rf
            .withColumn('dims', rf_dimensions('tile'))
            .withColumn('type', rf_cell_type('tile'))
            .withColumn('dCells', rf_data_cells('tile'))
            .withColumn('ndCells', rf_no_data_cells('tile'))
            .withColumn('min', rf_tile_min('tile'))
            .withColumn('max', rf_tile_max('tile'))
            .withColumn('mean', rf_tile_mean('tile'))
            .withColumn('sum', rf_tile_sum('tile'))
            .withColumn('stats', rf_tile_stats('tile'))
            .withColumn('extent', st_extent('geometry'))
            .withColumn('extent_geom1', st_geometry('extent'))
            .withColumn('ascii', rf_render_ascii('tile'))
            .withColumn('log', rf_log('tile'))
            .withColumn('exp', rf_exp('tile'))
            .withColumn('expm1', rf_expm1('tile'))
            .withColumn('round', rf_round('tile'))
            .withColumn('abs', rf_abs('tile'))
        )

        # Fetching a row forces Spark to evaluate every column added above.
        df.first()

    def test_agg_mean(self):
        """Aggregate mean of the tile column matches the expected value."""
        mean = self.rf.agg(rf_agg_mean('tile')).first()['rf_agg_mean(tile)']
        self.assertTrue(self.rounded_compare(mean, 10160))

    def test_aggregations(self):
        """Data/no-data cell counts and aggregate stats agree with each other."""
        aggs = self.rf.agg(
            rf_agg_data_cells('tile'),
            rf_agg_no_data_cells('tile'),
            rf_agg_stats('tile'),
            rf_agg_approx_histogram('tile'),
        )
        row = aggs.first()

        self.assertEqual(row['rf_agg_data_cells(tile)'], self._DATA_CELLS)
        self.assertEqual(row['rf_agg_no_data_cells(tile)'], 1000)
        self.assertEqual(row['rf_agg_stats(tile)'].data_cells, row['rf_agg_data_cells(tile)'])

    def test_sql(self):
        """Local arithmetic functions invoked through Spark SQL shift the tile mean as expected."""
        self.rf.createOrReplaceTempView("rf_test_sql")

        # The SQL literals are reproduced exactly as they appear in the
        # source, hence the flush-left continuation lines.
        self.spark.sql("""SELECT tile,
rf_local_add(tile, 1) AS and_one,
rf_local_subtract(tile, 1) AS less_one,
rf_local_multiply(tile, 2) AS times_two,
rf_local_divide(tile, 2) AS over_two
FROM rf_test_sql""").createOrReplaceTempView('rf_test_sql_1')

        statsRow = self.spark.sql("""
SELECT rf_tile_mean(tile) as base,
rf_tile_mean(and_one) as plus_one,
rf_tile_mean(less_one) as minus_one,
rf_tile_mean(times_two) as double,
rf_tile_mean(over_two) as half
FROM rf_test_sql_1
""").first()

        self.assertTrue(self.rounded_compare(statsRow.base, statsRow.plus_one - 1))
        self.assertTrue(self.rounded_compare(statsRow.base, statsRow.minus_one + 1))
        self.assertTrue(self.rounded_compare(statsRow.base, statsRow.double / 2))
        self.assertTrue(self.rounded_compare(statsRow.base, statsRow.half * 2))

    def test_explode(self):
        """Exploding tiles yields per-cell rows; the sampled variant yields a bounded subset."""
        import pyspark.sql.functions as F
        self.rf.select('spatial_key', rf_explode_tiles('tile')).show()
        # +-----------+------------+---------+-------+
        # |spatial_key|column_index|row_index|tile   |
        # +-----------+------------+---------+-------+
        # |[2,1]      |4           |0        |10150.0|
        cell = (
            self.rf.select(self.rf.spatial_key_column(), rf_explode_tiles(self.rf.tile))
            .where(F.col("spatial_key.col") == 2)
            .where(F.col("spatial_key.row") == 1)
            .where(F.col("column_index") == 4)
            .where(F.col("row_index") == 0)
            .select(F.col("tile"))
            .collect()[0][0]
        )
        self.assertEqual(cell, 10150.0)

        # Test the sample version
        frac = 0.01
        sample_count = self.rf.select(rf_explode_tiles_sample(frac, 1872, 'tile')).count()
        print('Sample count is {}'.format(sample_count))
        self.assertGreater(sample_count, 0)
        # give some wiggle room above the nominal sampling fraction
        self.assertLess(sample_count, (frac * 1.1) * self._DATA_CELLS)

    def test_mask_by_value(self):
        """Masking (and inverse masking) by value increases the no-data cell count."""
        from pyspark.sql.functions import lit

        # create an artificial mask for values > 25000; masking value will be 4
        mask_value = 4

        rf1 = self.rf.select(
            self.rf.tile,
            rf_local_multiply(
                rf_convert_cell_type(
                    rf_local_greater_int(self.rf.tile, 25000),
                    "uint8"),
                lit(mask_value)).alias('mask'))
        rf2 = rf1.select(rf1.tile, rf_mask_by_value(rf1.tile, rf1.mask, lit(mask_value)).alias('masked'))
        result = (
            rf2.agg(rf_agg_no_data_cells(rf2.tile) < rf_agg_no_data_cells(rf2.masked))
            .collect()[0][0]
        )
        self.assertTrue(result)

        rf3 = rf1.select(rf1.tile, rf_inverse_mask_by_value(rf1.tile, rf1.mask, lit(mask_value)).alias('masked'))
        result = (
            rf3.agg(rf_agg_no_data_cells(rf3.tile) < rf_agg_no_data_cells(rf3.masked))
            .collect()[0][0]
        )
        self.assertTrue(result)

    def test_resample(self):
        """Resampling by 2 and then by 0.5 reproduces the original tile."""
        from pyspark.sql.functions import lit
        result = self.rf.select(
            rf_tile_min(rf_local_equal(
                rf_resample(rf_resample(self.rf.tile, lit(2)), lit(0.5)),
                self.rf.tile))
        ).collect()[0][0]

        # A minimum of 1 over the equality tile is shorthand for "all cells equal".
        self.assertEqual(result, 1)

    def test_exists_for_all(self):
        """`rf_exists` / `rf_for_all` are true for a ones tile and false for a zeros tile."""
        df = (
            self.rf
            .withColumn('should_exist', rf_make_ones_tile(5, 5, 'int8'))
            .withColumn('should_not_exist', rf_make_zeros_tile(5, 5, 'int8'))
        )

        should_exist = df.select(rf_exists(df.should_exist).alias('se')).take(1)[0].se
        self.assertTrue(should_exist)

        should_not_exist = df.select(rf_exists(df.should_not_exist).alias('se')).take(1)[0].se
        self.assertFalse(should_not_exist)

        self.assertTrue(df.select(rf_for_all(df.should_exist).alias('se')).take(1)[0].se)
        self.assertFalse(df.select(rf_for_all(df.should_not_exist).alias('se')).take(1)[0].se)
205-
20631

20732
class CellTypeHandling(unittest.TestCase):
20833

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
#
2+
# This software is licensed under the Apache 2 license, quoted below.
3+
#
4+
# Copyright 2019 Astraea, Inc.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
7+
# use this file except in compliance with the License. You may obtain a copy of
8+
# the License at
9+
#
10+
# [http://www.apache.org/licenses/LICENSE-2.0]
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15+
# License for the specific language governing permissions and limitations under
16+
# the License.
17+
#
18+
# SPDX-License-Identifier: Apache-2.0
19+
#
20+
21+
import unittest
22+
23+
import numpy as np
24+
from pyrasterframes.rasterfunctions import *
25+
from pyrasterframes.rf_types import *
26+
from pyspark.sql import SQLContext
27+
from pyspark.sql.functions import *
28+
from . import TestEnvironment
29+
30+
31+
class RasterFunctions(TestEnvironment):
    """Tests for the PyRasterFrames raster column functions.

    Each test reads ``self.rf`` and ``self.spark``; ``setUp`` calls
    ``create_layer()`` first (presumably the ``TestEnvironment`` hook that
    builds the layer -- confirm against ``TestEnvironment``).
    """

    # Data-cell count asserted for the test layer in ``test_aggregations``;
    # ``test_explode`` uses the same figure as the base of its sampling bound.
    _DATA_CELLS = 387000

    def setUp(self):
        """Prepare the layer used by every test in this class."""
        self.create_layer()

    def test_setup(self):
        """The Spark session must be configured with the Kryo serializer."""
        self.assertEqual(self.spark.sparkContext.getConf().get("spark.serializer"),
                         "org.apache.spark.serializer.KryoSerializer")

    def test_identify_columns(self):
        """Tile, spatial-key and temporal-key column discovery on the layer."""
        cols = self.rf.tile_columns()
        self.assertEqual(len(cols), 1, '`tileColumns` did not find the proper number of columns.')
        print("Tile columns: ", cols)
        col = self.rf.spatial_key_column()
        self.assertIsInstance(col, Column, '`spatialKeyColumn` was not found')
        print("Spatial key column: ", col)
        col = self.rf.temporal_key_column()
        self.assertIsNone(col, '`temporalKeyColumn` should be `None`')
        print("Temporal key column: ", col)

    def test_tile_creation(self):
        """Constant/zeros/ones tile constructors yield one row per input row."""
        base = self.spark.createDataFrame([1, 2, 3, 4], 'integer')
        tiles = base.select(
            rf_make_constant_tile(3, 3, 3, "int32"),
            rf_make_zeros_tile(3, 3, "int32"),
            rf_make_ones_tile(3, 3, "int32"),
        )
        tiles.show()
        self.assertEqual(tiles.count(), 4)

    def test_multi_column_operations(self):
        """Normalized difference of a tile column with itself averages to zero."""
        df1 = self.rf.withColumnRenamed('tile', 't1').as_layer()
        df2 = self.rf.withColumnRenamed('tile', 't2').as_layer()
        df3 = df1.spatial_join(df2).as_layer()
        df3 = df3.withColumn('norm_diff', rf_normalized_difference('t1', 't2'))

        aggs = df3.agg(
            rf_agg_mean('norm_diff'),
        )
        aggs.show()
        row = aggs.first()

        self.assertTrue(self.rounded_compare(row['rf_agg_mean(norm_diff)'], 0))

    def test_general(self):
        """Smoke test: a wide selection of tile functions evaluates without error."""
        meta = self.rf.tile_layer_metadata()
        self.assertIsNotNone(meta['bounds'])
        df = (
            self.rf
            .withColumn('dims', rf_dimensions('tile'))
            .withColumn('type', rf_cell_type('tile'))
            .withColumn('dCells', rf_data_cells('tile'))
            .withColumn('ndCells', rf_no_data_cells('tile'))
            .withColumn('min', rf_tile_min('tile'))
            .withColumn('max', rf_tile_max('tile'))
            .withColumn('mean', rf_tile_mean('tile'))
            .withColumn('sum', rf_tile_sum('tile'))
            .withColumn('stats', rf_tile_stats('tile'))
            .withColumn('extent', st_extent('geometry'))
            .withColumn('extent_geom1', st_geometry('extent'))
            .withColumn('ascii', rf_render_ascii('tile'))
            .withColumn('log', rf_log('tile'))
            .withColumn('exp', rf_exp('tile'))
            .withColumn('expm1', rf_expm1('tile'))
            .withColumn('round', rf_round('tile'))
            .withColumn('abs', rf_abs('tile'))
        )

        # Fetching a row forces Spark to evaluate every column added above.
        df.first()

    def test_agg_mean(self):
        """Aggregate mean of the tile column matches the expected value."""
        mean = self.rf.agg(rf_agg_mean('tile')).first()['rf_agg_mean(tile)']
        self.assertTrue(self.rounded_compare(mean, 10160))

    def test_aggregations(self):
        """Data/no-data cell counts and aggregate stats agree with each other."""
        aggs = self.rf.agg(
            rf_agg_data_cells('tile'),
            rf_agg_no_data_cells('tile'),
            rf_agg_stats('tile'),
            rf_agg_approx_histogram('tile'),
        )
        row = aggs.first()

        self.assertEqual(row['rf_agg_data_cells(tile)'], self._DATA_CELLS)
        self.assertEqual(row['rf_agg_no_data_cells(tile)'], 1000)
        self.assertEqual(row['rf_agg_stats(tile)'].data_cells, row['rf_agg_data_cells(tile)'])

    def test_sql(self):
        """Local arithmetic functions invoked through Spark SQL shift the tile mean as expected."""
        self.rf.createOrReplaceTempView("rf_test_sql")

        # The SQL literals are reproduced exactly as they appear in the
        # source, hence the flush-left continuation lines.
        self.spark.sql("""SELECT tile,
rf_local_add(tile, 1) AS and_one,
rf_local_subtract(tile, 1) AS less_one,
rf_local_multiply(tile, 2) AS times_two,
rf_local_divide(tile, 2) AS over_two
FROM rf_test_sql""").createOrReplaceTempView('rf_test_sql_1')

        statsRow = self.spark.sql("""
SELECT rf_tile_mean(tile) as base,
rf_tile_mean(and_one) as plus_one,
rf_tile_mean(less_one) as minus_one,
rf_tile_mean(times_two) as double,
rf_tile_mean(over_two) as half
FROM rf_test_sql_1
""").first()

        self.assertTrue(self.rounded_compare(statsRow.base, statsRow.plus_one - 1))
        self.assertTrue(self.rounded_compare(statsRow.base, statsRow.minus_one + 1))
        self.assertTrue(self.rounded_compare(statsRow.base, statsRow.double / 2))
        self.assertTrue(self.rounded_compare(statsRow.base, statsRow.half * 2))

    def test_explode(self):
        """Exploding tiles yields per-cell rows; the sampled variant yields a bounded subset."""
        import pyspark.sql.functions as F
        self.rf.select('spatial_key', rf_explode_tiles('tile')).show()
        # +-----------+------------+---------+-------+
        # |spatial_key|column_index|row_index|tile   |
        # +-----------+------------+---------+-------+
        # |[2,1]      |4           |0        |10150.0|
        cell = (
            self.rf.select(self.rf.spatial_key_column(), rf_explode_tiles(self.rf.tile))
            .where(F.col("spatial_key.col") == 2)
            .where(F.col("spatial_key.row") == 1)
            .where(F.col("column_index") == 4)
            .where(F.col("row_index") == 0)
            .select(F.col("tile"))
            .collect()[0][0]
        )
        self.assertEqual(cell, 10150.0)

        # Test the sample version
        frac = 0.01
        sample_count = self.rf.select(rf_explode_tiles_sample(frac, 1872, 'tile')).count()
        print('Sample count is {}'.format(sample_count))
        self.assertGreater(sample_count, 0)
        # give some wiggle room above the nominal sampling fraction
        self.assertLess(sample_count, (frac * 1.1) * self._DATA_CELLS)

    def test_mask_by_value(self):
        """Masking (and inverse masking) by value increases the no-data cell count."""
        from pyspark.sql.functions import lit

        # create an artificial mask for values > 25000; masking value will be 4
        mask_value = 4

        rf1 = self.rf.select(
            self.rf.tile,
            rf_local_multiply(
                rf_convert_cell_type(
                    rf_local_greater_int(self.rf.tile, 25000),
                    "uint8"),
                lit(mask_value)).alias('mask'))
        rf2 = rf1.select(rf1.tile, rf_mask_by_value(rf1.tile, rf1.mask, lit(mask_value)).alias('masked'))
        result = (
            rf2.agg(rf_agg_no_data_cells(rf2.tile) < rf_agg_no_data_cells(rf2.masked))
            .collect()[0][0]
        )
        self.assertTrue(result)

        rf3 = rf1.select(rf1.tile, rf_inverse_mask_by_value(rf1.tile, rf1.mask, lit(mask_value)).alias('masked'))
        result = (
            rf3.agg(rf_agg_no_data_cells(rf3.tile) < rf_agg_no_data_cells(rf3.masked))
            .collect()[0][0]
        )
        self.assertTrue(result)

    def test_resample(self):
        """Resampling by 2 and then by 0.5 reproduces the original tile."""
        from pyspark.sql.functions import lit
        result = self.rf.select(
            rf_tile_min(rf_local_equal(
                rf_resample(rf_resample(self.rf.tile, lit(2)), lit(0.5)),
                self.rf.tile))
        ).collect()[0][0]

        # A minimum of 1 over the equality tile is shorthand for "all cells equal".
        self.assertEqual(result, 1)

    def test_exists_for_all(self):
        """`rf_exists` / `rf_for_all` are true for a ones tile and false for a zeros tile."""
        df = (
            self.rf
            .withColumn('should_exist', rf_make_ones_tile(5, 5, 'int8'))
            .withColumn('should_not_exist', rf_make_zeros_tile(5, 5, 'int8'))
        )

        should_exist = df.select(rf_exists(df.should_exist).alias('se')).take(1)[0].se
        self.assertTrue(should_exist)

        should_not_exist = df.select(rf_exists(df.should_not_exist).alias('se')).take(1)[0].se
        self.assertFalse(should_not_exist)

        self.assertTrue(df.select(rf_for_all(df.should_exist).alias('se')).take(1)[0].se)
        self.assertFalse(df.select(rf_for_all(df.should_not_exist).alias('se')).take(1)[0].se)

0 commit comments

Comments
 (0)