Skip to content

Commit 3751cf5

Browse files
committed
singleton for global context
1 parent 597ef33 commit 3751cf5

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

python/datafusion/context.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,8 @@ class SessionContext:
468468
See :ref:`user_guide_concepts` in the online documentation for more information.
469469
"""
470470

471+
_global_instance = None
472+
471473
def __init__(
472474
self,
473475
config: SessionConfig | None = None,
@@ -505,7 +507,10 @@ def global_ctx(cls) -> "SessionContextInternal":
505507
Returns:
506508
A `SessionContextInternal` object that corresponds to the global context
507509
"""
508-
return SessionContextInternal.global_ctx()
510+
if cls._global_instance is None:
511+
internal_ctx = SessionContextInternal.global_ctx()
512+
cls._global_instance = internal_ctx
513+
return cls._global_instance
509514

510515
def enable_url_table(self) -> "SessionContext":
511516
"""Control if local files can be queried as tables.

python/tests/test_context.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,11 @@
3030
SQLOptions,
3131
column,
3232
literal,
33+
udf,
3334
)
3435

36+
from datafusion._internal import SessionContext as SessionContextInternal
37+
3538

3639
def test_create_context_no_args():
3740
SessionContext()
@@ -629,3 +632,32 @@ def test_sql_with_options_no_statements(ctx):
629632
options = SQLOptions().with_allow_statements(False)
630633
with pytest.raises(Exception, match="SetVariable"):
631634
ctx.sql_with_options(sql, options=options)
635+
636+
637+
def test_global_context_type():
638+
ctx = SessionContext.global_ctx()
639+
assert isinstance(ctx, SessionContextInternal)
640+
641+
642+
def test_global_context_is_singleton():
643+
ctx1 = SessionContext.global_ctx()
644+
ctx2 = SessionContext.global_ctx()
645+
assert ctx1 is ctx2
646+
647+
648+
@pytest.fixture
649+
def batch():
650+
return pa.RecordBatch.from_arrays(
651+
[pa.array([4, 5, 6])],
652+
names=["a"],
653+
)
654+
655+
656+
def test_create_dataframe_with_global_ctx(batch):
657+
ctx = SessionContext.global_ctx()
658+
659+
df = ctx.create_dataframe([[batch]])
660+
661+
result = df.collect()[0].column(0)
662+
663+
assert result == pa.array([4, 5, 6])

0 commit comments

Comments
 (0)