Skip to content

Commit 2581ca1

Browse files
ueshin authored and HyukjinKwon committed
[SPARK-51058][PYTHON] Avoid using jvm.SparkSession
### What changes were proposed in this pull request? Avoids using `jvm.SparkSession` style, to improve Py4J performance similar to #49312, #49313, and #49412. ### Why are the changes needed? To reduce the overhead of Py4J calls. ```py import time def benchmark(f, _n=10, *args, **kwargs): start = time.time() for i in range(_n): f(*args, **kwargs) print(time.time() - start) ``` ```py from pyspark.context import SparkContext jvm = SparkContext._jvm def f(): return jvm.SparkSession benchmark(f, 10000) # -> 3.578310251235962 ``` ```py from pyspark.context import SparkContext jvm = SparkContext._jvm def g(): return getattr(jvm, "org.apache.spark.sql.classic.SparkSession") benchmark(g, 10000) # -> 0.254807710647583 ``` ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? The existing tests should pass. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #49760 from ueshin/issues/SPARK-51058/spark_session. Authored-by: Takuya Ueshin <ueshin@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
1 parent 400e2b2 commit 2581ca1

File tree

2 files changed

+30
-22
lines changed

2 files changed

+30
-22
lines changed

python/pyspark/ml/connect/tuning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def single_task() -> Tuple[int, float]:
180180
if not is_remote():
181181
# Active session is thread-local variable, in background thread the active session
182182
# is not set, the following line sets it as the main thread active session.
183-
active_session._jvm.SparkSession.setActiveSession( # type: ignore[union-attr]
183+
SparkSession._get_j_spark_session_class(active_session._jvm).setActiveSession(
184184
active_session._jsparkSession
185185
)
186186

python/pyspark/sql/session.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
from pyspark.errors import PySparkValueError, PySparkTypeError, PySparkRuntimeError
7272

7373
if TYPE_CHECKING:
74-
from py4j.java_gateway import JavaObject
74+
from py4j.java_gateway import JavaClass, JavaObject, JVMView
7575
import pyarrow as pa
7676
from pyspark.core.context import SparkContext
7777
from pyspark.core.rdd import RDD
@@ -542,9 +542,8 @@ def getOrCreate(self) -> "SparkSession":
542542
# by all sessions.
543543
session = SparkSession(sc, options=self._options)
544544
else:
545-
getattr(
546-
getattr(session._jvm, "SparkSession$"), "MODULE$"
547-
).applyModifiableSettings(session._jsparkSession, self._options)
545+
module = SparkSession._get_j_spark_session_module(session._jvm)
546+
module.applyModifiableSettings(session._jsparkSession, self._options)
548547
return session
549548

550549
# Spark Connect-specific API
@@ -612,21 +611,20 @@ def __init__(
612611

613612
assert self._jvm is not None
614613

614+
jSparkSessionClass = SparkSession._get_j_spark_session_class(self._jvm)
615+
jSparkSessionModule = SparkSession._get_j_spark_session_module(self._jvm)
616+
615617
if jsparkSession is None:
616618
if (
617-
self._jvm.SparkSession.getDefaultSession().isDefined()
618-
and not self._jvm.SparkSession.getDefaultSession().get().sparkContext().isStopped()
619+
jSparkSessionClass.getDefaultSession().isDefined()
620+
and not jSparkSessionClass.getDefaultSession().get().sparkContext().isStopped()
619621
):
620-
jsparkSession = self._jvm.SparkSession.getDefaultSession().get()
621-
getattr(getattr(self._jvm, "SparkSession$"), "MODULE$").applyModifiableSettings(
622-
jsparkSession, options
623-
)
622+
jsparkSession = jSparkSessionClass.getDefaultSession().get()
623+
jSparkSessionModule.applyModifiableSettings(jsparkSession, options)
624624
else:
625-
jsparkSession = self._jvm.SparkSession(self._jsc.sc(), options)
625+
jsparkSession = jSparkSessionClass(self._jsc.sc(), options)
626626
else:
627-
getattr(getattr(self._jvm, "SparkSession$"), "MODULE$").applyModifiableSettings(
628-
jsparkSession, options
629-
)
627+
jSparkSessionModule.applyModifiableSettings(jsparkSession, options)
630628
self._jsparkSession = jsparkSession
631629
_monkey_patch_RDD(self)
632630
install_exception_handler()
@@ -637,8 +635,8 @@ def __init__(
637635
SparkSession._instantiatedSession = self
638636
SparkSession._activeSession = self
639637
assert self._jvm is not None
640-
self._jvm.SparkSession.setDefaultSession(self._jsparkSession)
641-
self._jvm.SparkSession.setActiveSession(self._jsparkSession)
638+
jSparkSessionClass.setDefaultSession(self._jsparkSession)
639+
jSparkSessionClass.setActiveSession(self._jsparkSession)
642640

643641
self._profiler_collector = AccumulatorProfilerCollector()
644642

@@ -649,6 +647,14 @@ def _should_update_active_session() -> bool:
649647
or SparkSession._instantiatedSession._sc._jsc is None
650648
)
651649

650+
@staticmethod
651+
def _get_j_spark_session_class(jvm: "JVMView") -> "JavaClass":
652+
return getattr(jvm, "org.apache.spark.sql.classic.SparkSession")
653+
654+
@staticmethod
655+
def _get_j_spark_session_module(jvm: "JVMView") -> "JavaObject":
656+
return getattr(getattr(jvm, "org.apache.spark.sql.classic.SparkSession$"), "MODULE$")
657+
652658
def _repr_html_(self) -> str:
653659
return """
654660
<div>
@@ -721,9 +727,10 @@ def getActiveSession(cls) -> Optional["SparkSession"]:
721727
return None
722728
else:
723729
assert sc._jvm is not None
724-
if sc._jvm.SparkSession.getActiveSession().isDefined():
730+
jSparkSessionClass = SparkSession._get_j_spark_session_class(sc._jvm)
731+
if jSparkSessionClass.getActiveSession().isDefined():
725732
if SparkSession._should_update_active_session():
726-
SparkSession(sc, sc._jvm.SparkSession.getActiveSession().get())
733+
SparkSession(sc, jSparkSessionClass.getActiveSession().get())
727734
return SparkSession._activeSession
728735
else:
729736
return None
@@ -1501,7 +1508,7 @@ def createDataFrame( # type: ignore[misc]
15011508
"""
15021509
SparkSession._activeSession = self
15031510
assert self._jvm is not None
1504-
self._jvm.SparkSession.setActiveSession(self._jsparkSession)
1511+
SparkSession._get_j_spark_session_class(self._jvm).setActiveSession(self._jsparkSession)
15051512
if isinstance(data, DataFrame):
15061513
raise PySparkTypeError(
15071514
errorClass="INVALID_TYPE",
@@ -2000,8 +2007,9 @@ def stop(self) -> None:
20002007
self._sc.stop()
20012008
# We should clean the default session up. See SPARK-23228.
20022009
assert self._jvm is not None
2003-
self._jvm.SparkSession.clearDefaultSession()
2004-
self._jvm.SparkSession.clearActiveSession()
2010+
jSparkSessionClass = SparkSession._get_j_spark_session_class(self._jvm)
2011+
jSparkSessionClass.clearDefaultSession()
2012+
jSparkSessionClass.clearActiveSession()
20052013
SparkSession._instantiatedSession = None
20062014
SparkSession._activeSession = None
20072015
SQLContext._instantiatedContext = None

0 commit comments

Comments (0)