@@ -228,28 +228,23 @@ def test_aggregations(self):
228228 self .assertEqual (row ['rf_agg_stats(tile)' ].data_cells , row ['rf_agg_data_cells(tile)' ])
229229
230230 def test_sql (self ):
231- self .rf .createOrReplaceTempView ("rf" )
232-
233- dims = self .rf .withColumn ('dims' , rf_dimensions ('tile' )).first ().dims
234- dims_str = """{}, {}""" .format (dims .cols , dims .rows )
235-
236- self .spark .sql ("""SELECT tile, rf_make_constant_tile(1, {}, 'uint16') AS One,
237- rf_make_constant_tile(2, {}, 'uint16') AS Two FROM rf""" .format (dims_str , dims_str )) \
238- .createOrReplaceTempView ("r3" )
239-
240- ops = self .spark .sql ("""SELECT tile, rf_local_add(tile, One) AS AndOne,
241- rf_local_subtract(tile, One) AS LessOne,
242- rf_local_multiply(tile, Two) AS TimesTwo,
243- rf_local_divide(tile, Two) AS OverTwo
244- FROM r3""" )
245-
246- # ops.printSchema
247- statsRow = ops .select (rf_tile_mean ('tile' ).alias ('base' ),
248- rf_tile_mean ("AndOne" ).alias ('plus_one' ),
249- rf_tile_mean ("LessOne" ).alias ('minus_one' ),
250- rf_tile_mean ("TimesTwo" ).alias ('double' ),
251- rf_tile_mean ("OverTwo" ).alias ('half' )) \
252- .first ()
231+ self .rf .createOrReplaceTempView ("rf_test_sql" )
232+
233+ self .spark .sql ("""SELECT tile,
234+ rf_local_add(tile, 1) AS and_one,
235+ rf_local_subtract(tile, 1) AS less_one,
236+ rf_local_multiply(tile, 2) AS times_two,
237+ rf_local_divide(tile, 2) AS over_two
238+ FROM rf_test_sql""" ).createOrReplaceTempView ('rf_test_sql_1' )
239+
240+ statsRow = self .spark .sql ("""
241+ SELECT rf_tile_mean(tile) as base,
242+ rf_tile_mean(and_one) as plus_one,
243+ rf_tile_mean(less_one) as minus_one,
244+ rf_tile_mean(times_two) as double,
245+ rf_tile_mean(over_two) as half
246+ FROM rf_test_sql_1
247+ """ ).first ()
253248
254249 self .assertTrue (self .rounded_compare (statsRow .base , statsRow .plus_one - 1 ))
255250 self .assertTrue (self .rounded_compare (statsRow .base , statsRow .minus_one + 1 ))
@@ -532,8 +527,6 @@ def less_pi(t):
532527
533528class TileOps (TestEnvironment ):
534529
535- from pyrasterframes .rf_types import Tile
536-
537530 def setUp (self ):
538531 # convenience so we can assert around Tile() == Tile()
539532 self .t1 = Tile (np .array ([[1 , 2 ],
@@ -589,9 +582,11 @@ def test_matmul(self):
589582 # r1 = self.t1 @ self.t2
590583 r1 = self .t1 .__matmul__ (self .t2 )
591584
592- nd = r1 .cell_type .no_data_value ()
593- e1 = Tile (np .ma .masked_equal (np .array ([[nd , 10 ],
594- [nd , nd ]], dtype = r1 .cell_type .to_numpy_dtype ()), nd ))
585+ # The behavior of np.matmul with masked arrays is not well documented
586+ # it seems to treat the 2nd arg as if not a MaskedArray
587+ e1 = Tile (np .matmul (self .t1 .cells , self .t2 .cells ), r1 .cell_type )
588+
589+ self .assertTrue (r1 == e1 , "{} was not equal to {}" .format (r1 , e1 ))
595590 self .assertEqual (r1 , e1 )
596591
597592
@@ -714,7 +709,7 @@ def test_strict_eval(self):
714709 # again for strict
715710 df_strict = self .spark .read .raster (self .img_uri , lazy_tiles = False )
716711 show_str_strict = df_strict .select ('proj_raster' )._jdf .showString (1 , - 1 , False )
717- self .assertTrue ('RasterRef' not in show_str_lazy )
712+ self .assertTrue ('RasterRef' not in show_str_strict )
718713
719714
720715 def test_prt_functions (self ):
0 commit comments