@@ -409,7 +409,7 @@ def setUp(self):
409409 def test_raster_join (self ):
410410 # re-read the same source
411411 rf_prime = self .spark .read .geotiff (self .img_uri ) \
412- .withColumnRenamed ('tile' , 'tile2' ). alias ( 'rf_prime' )
412+ .withColumnRenamed ('tile' , 'tile2' )
413413
414414 rf_joined = self .rf .raster_join (rf_prime )
415415
@@ -428,18 +428,28 @@ def test_raster_join(self):
428428 self .assertTrue (rf_joined_3 .count (), self .rf .count ())
429429 self .assertTrue (len (rf_joined_3 .columns ) == len (self .rf .columns ) + len (rf_prime .columns ) - 2 )
430430
431- result_methods = self .rf \
432- .raster_join (rf_prime .withColumnRenamed ('tile2' , 'bilinear' ), "bilinear" ) \
433- .raster_join (rf_prime .withColumnRenamed ('tile2' , 'cubic_spline' ), "cubic_spline" ) \
431+ # throws if you don't pass in all expected columns
432+ with self .assertRaises (AssertionError ):
433+ self .rf .raster_join (rf_prime , join_exprs = self .rf .extent )
434+
435+ def test_raster_join_resample_method (self ):
436+ import os
437+ from pyspark .sql .functions import col
438+ df = self .spark .read .raster ('file://' + os .path .join (self .resource_dir , 'L8-B4-Elkton-VA.tiff' )) \
439+ .select (col ('proj_raster' ).alias ('tile' ))
440+ df_prime = self .spark .read .raster ('file://' + os .path .join (self .resource_dir , 'L8-B4-Elkton-VA-4326.tiff' )) \
441+ .select (col ('proj_raster' ).alias ('tile2' ))
442+
443+ result_methods = df \
444+ .raster_join (df_prime .withColumnRenamed ('tile2' , 'bilinear' ), resampling_method = "bilinear" ) \
445+ .select ('tile' , rf_proj_raster ('bilinear' , rf_extent ('tile' ), rf_crs ('tile' )).alias ('bilinear' )) \
446+ .raster_join (df_prime .withColumnRenamed ('tile2' , 'cubic_spline' ), resampling_method = "cubic_spline" ) \
434447 .select (rf_local_subtract ('bilinear' , 'cubic_spline' ).alias ('diff' )) \
435448 .agg (rf_agg_stats ('diff' ).alias ('stats' )) \
436449 .select ("stats.min" ) \
437450 .first ()
438- self .assertGreater (result_methods , 0.0 )
439451
440- # throws if you don't pass in all expected columns
441- with self .assertRaises (AssertionError ):
442- self .rf .raster_join (rf_prime , join_exprs = self .rf .extent )
452+ self .assertGreater (result_methods [0 ], 0.0 )
443453
444454 def test_raster_join_with_null_left_head (self ):
445455 # https://github.com/locationtech/rasterframes/issues/462
0 commit comments