Skip to content
Merged
12 changes: 12 additions & 0 deletions python/datafusion/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,18 @@ def __init__(

self.ctx = SessionContextInternal(config, runtime)

@classmethod
def global_ctx(cls) -> SessionContext:
"""Retrieve the global context as a `SessionContext` wrapper.

Returns:
A `SessionContext` object that wraps the global `SessionContextInternal`.
"""
internal_ctx = SessionContextInternal.global_ctx()
wrapper = cls()
wrapper.ctx = internal_ctx
return wrapper

def enable_url_table(self) -> SessionContext:
"""Control if local files can be queried as tables.

Expand Down
63 changes: 27 additions & 36 deletions python/datafusion/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,9 @@

from typing import TYPE_CHECKING

from datafusion.context import SessionContext
from datafusion.dataframe import DataFrame

from ._internal import SessionContext as SessionContextInternal

if TYPE_CHECKING:
import pathlib

Expand Down Expand Up @@ -68,16 +67,14 @@ def read_parquet(
"""
if table_partition_cols is None:
table_partition_cols = []
return DataFrame(
SessionContextInternal._global_ctx().read_parquet(
str(path),
table_partition_cols,
parquet_pruning,
file_extension,
skip_metadata,
schema,
file_sort_order,
)
return SessionContext.global_ctx().read_parquet(
str(path),
table_partition_cols,
parquet_pruning,
file_extension,
skip_metadata,
schema,
file_sort_order,
)


Expand Down Expand Up @@ -110,15 +107,13 @@ def read_json(
"""
if table_partition_cols is None:
table_partition_cols = []
return DataFrame(
SessionContextInternal._global_ctx().read_json(
str(path),
schema,
schema_infer_max_records,
file_extension,
table_partition_cols,
file_compression_type,
)
return SessionContext.global_ctx().read_json(
str(path),
schema,
schema_infer_max_records,
file_extension,
table_partition_cols,
file_compression_type,
)


Expand Down Expand Up @@ -161,17 +156,15 @@ def read_csv(

path = [str(p) for p in path] if isinstance(path, list) else str(path)

return DataFrame(
SessionContextInternal._global_ctx().read_csv(
path,
schema,
has_header,
delimiter,
schema_infer_max_records,
file_extension,
table_partition_cols,
file_compression_type,
)
return SessionContext.global_ctx().read_csv(
path,
schema,
has_header,
delimiter,
schema_infer_max_records,
file_extension,
table_partition_cols,
file_compression_type,
)


Expand All @@ -198,8 +191,6 @@ def read_avro(
"""
if file_partition_cols is None:
file_partition_cols = []
return DataFrame(
SessionContextInternal._global_ctx().read_avro(
str(path), schema, file_partition_cols, file_extension
)
return SessionContext.global_ctx().read_avro(
str(path), schema, file_partition_cols, file_extension
)
18 changes: 18 additions & 0 deletions python/tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,3 +632,21 @@ def test_sql_with_options_no_statements(ctx):
options = SQLOptions().with_allow_statements(allow=False)
with pytest.raises(Exception, match="SetVariable"):
ctx.sql_with_options(sql, options=options)


@pytest.fixture
def batch():
return pa.RecordBatch.from_arrays(
[pa.array([4, 5, 6])],
names=["a"],
)


def test_create_dataframe_with_global_ctx(batch):
ctx = SessionContext.global_ctx()

df = ctx.create_dataframe([[batch]])

result = df.collect()[0].column(0)

assert result == pa.array([4, 5, 6])
2 changes: 1 addition & 1 deletion src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ impl PySessionContext {

#[classmethod]
#[pyo3(signature = ())]
fn _global_ctx(_cls: &Bound<'_, PyType>) -> PyResult<Self> {
fn global_ctx(_cls: &Bound<'_, PyType>) -> PyResult<Self> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As you have it here where you moved the single entry over to the python side, this method goes unused. I would recommend you leave this line as is, but up in the python code you call this method instead of creating _global_instance

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your comments, just to summarize whats needed here:

  1. Expose the global context (_global_ctx -> global_ctx), which I've currently done.
  2. A python wrapper should be created for the global context (in the SessionContext class) which calls the above function and wraps it in SessionContext so that users can still use the other associated methods in this class, but with the global context. This should be a class method so that users dont have to instantiate SessionContext first.
  3. The read_* functions (read_parquet, etc) should use the global context from this python wrapper instead of using the one from the internal implementation.

Am I interpreting this correctly? Sorry if I'm overthinking this 😅. I've updated the PR, currently the test_read_csv and test_read_csv_list tests fail so I'm looking into that.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this description looks correct.

Ok(Self {
ctx: get_global_ctx().clone(),
})
Expand Down