#
# This software is licensed under the Apache 2 license, quoted below.
#
# Copyright 2019 Astraea, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# [http://www.apache.org/licenses/LICENSE-2.0]
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
#
# SPDX-License-Identifier: Apache-2.0
#

import unittest

import numpy as np
# NOTE: the wildcard imports below supply the rf_* column functions and
# rf_types names used throughout the tests; pyspark.sql.functions supplies
# `lit`, `col`, etc. Import order matters — do not reorder without checking
# for name shadowing between the two star-imported modules.
from pyrasterframes.rasterfunctions import *
from pyrasterframes.rf_types import *
from pyspark.sql import SQLContext
from pyspark.sql.functions import *
from . import TestEnvironment
| 29 | + |
| 30 | + |
class RasterFunctions(TestEnvironment):
    """Integration tests for the RasterFrames Python raster-function bindings.

    Each test operates on ``self.rf``, a RasterFrame layer created by
    ``TestEnvironment.create_layer()`` in ``setUp``. The fixture layer has a
    single tile column named ``tile`` (presumably a known test scene — the
    hard-coded expected values below, e.g. 387000 data cells, come from it;
    verify against ``TestEnvironment`` if they drift).
    """

    def setUp(self):
        # Rebuild the shared test layer (self.rf) before every test.
        self.create_layer()

    def test_setup(self):
        """Spark session must be configured with the Kryo serializer."""
        self.assertEqual(self.spark.sparkContext.getConf().get("spark.serializer"),
                         "org.apache.spark.serializer.KryoSerializer")

    def test_identify_columns(self):
        """Layer introspection: one tile column, a spatial key, no temporal key."""
        cols = self.rf.tile_columns()
        self.assertEqual(len(cols), 1, '`tileColumns` did not find the proper number of columns.')
        print("Tile columns: ", cols)
        col = self.rf.spatial_key_column()
        self.assertIsInstance(col, Column, '`spatialKeyColumn` was not found')
        print("Spatial key column: ", col)
        col = self.rf.temporal_key_column()
        # The fixture layer is purely spatial, so no temporal key is expected.
        self.assertIsNone(col, '`temporalKeyColumn` should be `None`')
        print("Temporal key column: ", col)

    def test_tile_creation(self):
        """Tile constructors produce one tile per input row (4 rows in, 4 out)."""
        base = self.spark.createDataFrame([1, 2, 3, 4], 'integer')
        tiles = base.select(rf_make_constant_tile(3, 3, 3, "int32"), rf_make_zeros_tile(3, 3, "int32"),
                            rf_make_ones_tile(3, 3, "int32"))
        tiles.show()
        self.assertEqual(tiles.count(), 4)

    def test_multi_column_operations(self):
        """Spatial self-join of two copies of the layer; the normalized
        difference of identical tiles should average to zero."""
        df1 = self.rf.withColumnRenamed('tile', 't1').as_layer()
        df2 = self.rf.withColumnRenamed('tile', 't2').as_layer()
        df3 = df1.spatial_join(df2).as_layer()
        df3 = df3.withColumn('norm_diff', rf_normalized_difference('t1', 't2'))

        aggs = df3.agg(
            rf_agg_mean('norm_diff'),
        )
        aggs.show()
        row = aggs.first()

        # Column name 'rf_agg_mean(norm_diff)' is auto-generated from the call form.
        self.assertTrue(self.rounded_compare(row['rf_agg_mean(norm_diff)'], 0))

    def test_general(self):
        """Smoke test: apply a broad sweep of tile/geometry functions and
        force evaluation with ``first()``; passes if nothing throws."""
        meta = self.rf.tile_layer_metadata()
        self.assertIsNotNone(meta['bounds'])
        df = self.rf.withColumn('dims', rf_dimensions('tile')) \
            .withColumn('type', rf_cell_type('tile')) \
            .withColumn('dCells', rf_data_cells('tile')) \
            .withColumn('ndCells', rf_no_data_cells('tile')) \
            .withColumn('min', rf_tile_min('tile')) \
            .withColumn('max', rf_tile_max('tile')) \
            .withColumn('mean', rf_tile_mean('tile')) \
            .withColumn('sum', rf_tile_sum('tile')) \
            .withColumn('stats', rf_tile_stats('tile')) \
            .withColumn('extent', st_extent('geometry')) \
            .withColumn('extent_geom1', st_geometry('extent')) \
            .withColumn('ascii', rf_render_ascii('tile')) \
            .withColumn('log', rf_log('tile')) \
            .withColumn('exp', rf_exp('tile')) \
            .withColumn('expm1', rf_expm1('tile')) \
            .withColumn('round', rf_round('tile')) \
            .withColumn('abs', rf_abs('tile'))

        df.first()

    def test_agg_mean(self):
        """Layer-wide aggregate mean of the fixture tile column (~10160)."""
        mean = self.rf.agg(rf_agg_mean('tile')).first()['rf_agg_mean(tile)']
        self.assertTrue(self.rounded_compare(mean, 10160))

    def test_aggregations(self):
        """Cell-count and stats aggregates against known fixture totals."""
        aggs = self.rf.agg(
            rf_agg_data_cells('tile'),
            rf_agg_no_data_cells('tile'),
            rf_agg_stats('tile'),
            rf_agg_approx_histogram('tile')
        )
        row = aggs.first()

        # Expected counts are properties of the fixture scene.
        self.assertEqual(row['rf_agg_data_cells(tile)'], 387000)
        self.assertEqual(row['rf_agg_no_data_cells(tile)'], 1000)
        # agg_stats' data-cell count must agree with the standalone aggregate.
        self.assertEqual(row['rf_agg_stats(tile)'].data_cells, row['rf_agg_data_cells(tile)'])

    def test_sql(self):
        """Exercise the SQL-registered local arithmetic functions and verify
        the tile means move by the expected amounts."""
        self.rf.createOrReplaceTempView("rf_test_sql")

        self.spark.sql("""SELECT tile,
            rf_local_add(tile, 1) AS and_one,
            rf_local_subtract(tile, 1) AS less_one,
            rf_local_multiply(tile, 2) AS times_two,
            rf_local_divide(tile, 2) AS over_two
            FROM rf_test_sql""").createOrReplaceTempView('rf_test_sql_1')

        statsRow = self.spark.sql("""
            SELECT rf_tile_mean(tile) as base,
                rf_tile_mean(and_one) as plus_one,
                rf_tile_mean(less_one) as minus_one,
                rf_tile_mean(times_two) as double,
                rf_tile_mean(over_two) as half
            FROM rf_test_sql_1
            """).first()

        # Mean is linear, so each transform shifts/scales the mean exactly.
        self.assertTrue(self.rounded_compare(statsRow.base, statsRow.plus_one - 1))
        self.assertTrue(self.rounded_compare(statsRow.base, statsRow.minus_one + 1))
        self.assertTrue(self.rounded_compare(statsRow.base, statsRow.double / 2))
        self.assertTrue(self.rounded_compare(statsRow.base, statsRow.half * 2))

    def test_explode(self):
        """rf_explode_tiles flattens tiles to (column_index, row_index, cell)
        rows; spot-check one known cell value, then the sampled variant."""
        import pyspark.sql.functions as F
        self.rf.select('spatial_key', rf_explode_tiles('tile')).show()
        # Example of the exploded shape:
        # +-----------+------------+---------+-------+
        # |spatial_key|column_index|row_index|tile   |
        # +-----------+------------+---------+-------+
        # |[2,1]      |4           |0        |10150.0|
        cell = self.rf.select(self.rf.spatial_key_column(), rf_explode_tiles(self.rf.tile)) \
            .where(F.col("spatial_key.col") == 2) \
            .where(F.col("spatial_key.row") == 1) \
            .where(F.col("column_index") == 4) \
            .where(F.col("row_index") == 0) \
            .select(F.col("tile")) \
            .collect()[0][0]
        self.assertEqual(cell, 10150.0)

        # Sampled explode: with a fixed seed (1872) expect a non-empty sample
        # no larger than ~fraction * total cells (387000), with wiggle room.
        frac = 0.01
        sample_count = self.rf.select(rf_explode_tiles_sample(frac, 1872, 'tile')).count()
        print('Sample count is {}'.format(sample_count))
        self.assertTrue(sample_count > 0)
        self.assertTrue(sample_count < (frac * 1.1) * 387000)  # give some wiggle room

    def test_mask_by_value(self):
        """Masking by a sentinel value must strictly increase no-data cells,
        for both the direct and inverse mask variants."""
        from pyspark.sql.functions import lit

        # Create an artificial mask for values > 25000; masking value will be 4.
        mask_value = 4

        # mask = (tile > 25000) cast to uint8, then scaled to the sentinel value.
        rf1 = self.rf.select(self.rf.tile,
                             rf_local_multiply(
                                 rf_convert_cell_type(
                                     rf_local_greater_int(self.rf.tile, 25000),
                                     "uint8"),
                                 lit(mask_value)).alias('mask'))
        rf2 = rf1.select(rf1.tile, rf_mask_by_value(rf1.tile, rf1.mask, lit(mask_value)).alias('masked'))
        result = rf2.agg(rf_agg_no_data_cells(rf2.tile) < rf_agg_no_data_cells(rf2.masked)) \
            .collect()[0][0]
        self.assertTrue(result)

        rf3 = rf1.select(rf1.tile, rf_inverse_mask_by_value(rf1.tile, rf1.mask, lit(mask_value)).alias('masked'))
        result = rf3.agg(rf_agg_no_data_cells(rf3.tile) < rf_agg_no_data_cells(rf3.masked)) \
            .collect()[0][0]
        self.assertTrue(result)

    def test_resample(self):
        """Up-sampling by 2 then down-sampling by 0.5 should round-trip to the
        original tile (checked via rf_local_equal / rf_tile_min)."""
        from pyspark.sql.functions import lit
        result = self.rf.select(
            rf_tile_min(rf_local_equal(
                rf_resample(rf_resample(self.rf.tile, lit(2)), lit(0.5)),
                self.rf.tile))
        ).collect()[0][0]

        # min == 1 over the equality tile is shorthand for "all cells equal".
        self.assertTrue(result == 1)

    def test_exists_for_all(self):
        """rf_exists is true iff any cell is nonzero; rf_for_all iff all are."""
        df = self.rf.withColumn('should_exist', rf_make_ones_tile(5, 5, 'int8')) \
            .withColumn('should_not_exist', rf_make_zeros_tile(5, 5, 'int8'))

        should_exist = df.select(rf_exists(df.should_exist).alias('se')).take(1)[0].se
        self.assertTrue(should_exist)

        should_not_exist = df.select(rf_exists(df.should_not_exist).alias('se')).take(1)[0].se
        self.assertTrue(not should_not_exist)

        self.assertTrue(df.select(rf_for_all(df.should_exist).alias('se')).take(1)[0].se)
        self.assertTrue(not df.select(rf_for_all(df.should_not_exist).alias('se')).take(1)[0].se)