Skip to content

Commit ee91e6d

Browse files
committed
Pass kwargs in python create_rf_spark_session to spark conf
Signed-off-by: Jason T. Brown <[email protected]>
1 parent f4a9a7c commit ee91e6d

File tree

4 files changed

+24
-4
lines changed

4 files changed

+24
-4
lines changed

pyrasterframes/src/main/python/pyrasterframes/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,11 @@ def _kryo_init(builder):
5555
return builder
5656

5757

58-
def get_spark_session(master="local[*]"):
58+
def get_spark_session(master="local[*]", **kwargs):
5959
""" Create a SparkSession with pyrasterframes enabled and configured. """
6060
from pyrasterframes.utils import create_rf_spark_session
6161

62-
return create_rf_spark_session(master)
62+
return create_rf_spark_session(master, **kwargs)
6363

6464

6565
def _convert_df(df, sp_key=None, metadata=None):

pyrasterframes/src/main/python/pyrasterframes/utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import glob
2222
from pyspark.sql import SparkSession
23+
from pyspark import SparkConf
2324
import os
2425
import sys
2526
from . import RFContext
@@ -76,15 +77,22 @@ def find_pyrasterframes_assembly():
7677
return jarpath[0]
7778

7879

79-
def create_rf_spark_session(master="local[*]"):
80+
def create_rf_spark_session(master="local[*]", **kwargs):
8081
""" Create a SparkSession with pyrasterframes enabled and configured. """
8182
jar_path = find_pyrasterframes_assembly()
8283

84+
if 'spark.jars' in kwargs.keys():
85+
if 'pyrasterframes' not in kwargs['spark.jars']:
86+
raise UserWarning("spark.jars config is set, but it seems to be missing the pyrasterframes assembly jar.")
87+
88+
conf = SparkConf().setAll([(k, kwargs[k]) for k in kwargs])
89+
8390
spark = (SparkSession.builder
8491
.master(master)
8592
.appName("RasterFrames")
8693
.config('spark.jars', jar_path)
8794
.withKryoSerialization()
95+
.config(conf=conf) # user can override the defaults
8896
.getOrCreate())
8997

9098
try:

pyrasterframes/src/main/python/tests/PyRasterFramesTests.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,14 @@
2828
from . import TestEnvironment
2929

3030

31+
class UtilTest(TestEnvironment):
32+
33+
def test_spark_confs(self):
34+
from . import app_name
35+
self.assertEqual(self.spark.conf.get('spark.app.name'), app_name)
36+
self.assertEqual(self.spark.conf.get('spark.ui.enabled'), 'false')
37+
38+
3139
class CellTypeHandling(unittest.TestCase):
3240

3341
def test_is_raw(self):

pyrasterframes/src/main/python/tests/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
else:
3232
import __builtin__ as builtins
3333

34+
app_name = 'pyrasterframes test suite'
3435

3536
def resource_dir():
3637
def pdir(curr):
@@ -46,7 +47,10 @@ def pdir(curr):
4647

4748

4849
def spark_test_session():
49-
spark = create_rf_spark_session()
50+
spark = create_rf_spark_session(**{
51+
'spark.ui.enabled': 'false',
52+
'spark.app.name': app_name
53+
})
5054
spark.sparkContext.setLogLevel('ERROR')
5155

5256
print("Spark Version: " + spark.version)

0 commit comments

Comments
 (0)