File tree Expand file tree Collapse file tree 4 files changed +24
-4
lines changed
pyrasterframes/src/main/python Expand file tree Collapse file tree 4 files changed +24
-4
lines changed Original file line number Diff line number Diff line change @@ -55,11 +55,11 @@ def _kryo_init(builder):
5555 return builder
5656
5757
58- def get_spark_session (master = "local[*]" ):
58+ def get_spark_session (master = "local[*]" , ** kwargs ):
5959 """ Create a SparkSession with pyrasterframes enabled and configured. """
6060 from pyrasterframes .utils import create_rf_spark_session
6161
62- return create_rf_spark_session (master )
62+ return create_rf_spark_session (master , ** kwargs )
6363
6464
6565def _convert_df (df , sp_key = None , metadata = None ):
Original file line number Diff line number Diff line change 2020
2121import glob
2222from pyspark .sql import SparkSession
23+ from pyspark import SparkConf
2324import os
2425import sys
2526from . import RFContext
@@ -76,15 +77,22 @@ def find_pyrasterframes_assembly():
7677 return jarpath [0 ]
7778
7879
79- def create_rf_spark_session (master = "local[*]" ):
80+ def create_rf_spark_session (master = "local[*]" , ** kwargs ):
8081 """ Create a SparkSession with pyrasterframes enabled and configured. """
8182 jar_path = find_pyrasterframes_assembly ()
8283
84+ if 'spark.jars' in kwargs .keys ():
85+ if 'pyrasterframes' not in kwargs ['spark.jars' ]:
86+ raise UserWarning ("spark.jars config is set, but it seems to be missing the pyrasterframes assembly jar." )
87+
88+ conf = SparkConf ().setAll ([(k , kwargs [k ]) for k in kwargs ])
89+
8390 spark = (SparkSession .builder
8491 .master (master )
8592 .appName ("RasterFrames" )
8693 .config ('spark.jars' , jar_path )
8794 .withKryoSerialization ()
95+ .config (conf = conf ) # user can override the defaults
8896 .getOrCreate ())
8997
9098 try :
Original file line number Diff line number Diff line change 2828from . import TestEnvironment
2929
3030
31+ class UtilTest (TestEnvironment ):
32+
33+ def test_spark_confs (self ):
34+ from . import app_name
35+ self .assertEqual (self .spark .conf .get ('spark.app.name' ), app_name )
36+ self .assertEqual (self .spark .conf .get ('spark.ui.enabled' ), 'false' )
37+
38+
3139class CellTypeHandling (unittest .TestCase ):
3240
3341 def test_is_raw (self ):
Original file line number Diff line number Diff line change 3131else :
3232 import __builtin__ as builtins
3333
34+ app_name = 'pyrasterframes test suite'
3435
3536def resource_dir ():
3637 def pdir (curr ):
@@ -46,7 +47,10 @@ def pdir(curr):
4647
4748
4849def spark_test_session ():
49- spark = create_rf_spark_session ()
50+ spark = create_rf_spark_session (** {
51+ 'spark.ui.enabled' : 'false' ,
52+ 'spark.app.name' : app_name
53+ })
5054 spark .sparkContext .setLogLevel ('ERROR' )
5155
5256 print ("Spark Version: " + spark .version )
You can’t perform that action at this time.
0 commit comments