diff --git a/python/datafusion/context.py b/python/datafusion/context.py
index 0ab1a908a..58ad9a943 100644
--- a/python/datafusion/context.py
+++ b/python/datafusion/context.py
@@ -496,6 +496,18 @@ def __init__(
 
         self.ctx = SessionContextInternal(config, runtime)
 
+    @classmethod
+    def global_ctx(cls) -> SessionContext:
+        """Retrieve the global context as a `SessionContext` wrapper.
+
+        Returns:
+            A `SessionContext` object that wraps the global `SessionContextInternal`.
+        """
+        internal_ctx = SessionContextInternal.global_ctx()
+        wrapper = cls()
+        wrapper.ctx = internal_ctx
+        return wrapper
+
     def enable_url_table(self) -> SessionContext:
         """Control if local files can be queried as tables.
 
diff --git a/python/datafusion/io.py b/python/datafusion/io.py
index 3e39703e3..ef5ebf96f 100644
--- a/python/datafusion/io.py
+++ b/python/datafusion/io.py
@@ -21,10 +21,9 @@
 
 from typing import TYPE_CHECKING
 
+from datafusion.context import SessionContext
 from datafusion.dataframe import DataFrame
 
-from ._internal import SessionContext as SessionContextInternal
-
 if TYPE_CHECKING:
     import pathlib
 
@@ -68,16 +67,14 @@ def read_parquet(
     """
     if table_partition_cols is None:
         table_partition_cols = []
-    return DataFrame(
-        SessionContextInternal._global_ctx().read_parquet(
-            str(path),
-            table_partition_cols,
-            parquet_pruning,
-            file_extension,
-            skip_metadata,
-            schema,
-            file_sort_order,
-        )
+    return SessionContext.global_ctx().read_parquet(
+        str(path),
+        table_partition_cols,
+        parquet_pruning,
+        file_extension,
+        skip_metadata,
+        schema,
+        file_sort_order,
     )
 
 
@@ -110,15 +107,13 @@ def read_json(
     """
     if table_partition_cols is None:
         table_partition_cols = []
-    return DataFrame(
-        SessionContextInternal._global_ctx().read_json(
-            str(path),
-            schema,
-            schema_infer_max_records,
-            file_extension,
-            table_partition_cols,
-            file_compression_type,
-        )
+    return SessionContext.global_ctx().read_json(
+        str(path),
+        schema,
+        schema_infer_max_records,
+        file_extension,
+        table_partition_cols,
+        file_compression_type,
     )
 
 
@@ -161,17 +156,15 @@ def read_csv(
 
     path = [str(p) for p in path] if isinstance(path, list) else str(path)
 
-    return DataFrame(
-        SessionContextInternal._global_ctx().read_csv(
-            path,
-            schema,
-            has_header,
-            delimiter,
-            schema_infer_max_records,
-            file_extension,
-            table_partition_cols,
-            file_compression_type,
-        )
+    return SessionContext.global_ctx().read_csv(
+        path,
+        schema,
+        has_header,
+        delimiter,
+        schema_infer_max_records,
+        file_extension,
+        table_partition_cols,
+        file_compression_type,
     )
 
 
@@ -198,8 +191,6 @@ def read_avro(
     """
    if file_partition_cols is None:
         file_partition_cols = []
-    return DataFrame(
-        SessionContextInternal._global_ctx().read_avro(
-            str(path), schema, file_partition_cols, file_extension
-        )
+    return SessionContext.global_ctx().read_avro(
+        str(path), schema, file_partition_cols, file_extension
     )
diff --git a/python/tests/test_context.py b/python/tests/test_context.py
index 7a0a7aa08..4a15ac9cf 100644
--- a/python/tests/test_context.py
+++ b/python/tests/test_context.py
@@ -632,3 +632,21 @@ def test_sql_with_options_no_statements(ctx):
     options = SQLOptions().with_allow_statements(allow=False)
     with pytest.raises(Exception, match="SetVariable"):
         ctx.sql_with_options(sql, options=options)
+
+
+@pytest.fixture
+def batch():
+    return pa.RecordBatch.from_arrays(
+        [pa.array([4, 5, 6])],
+        names=["a"],
+    )
+
+
+def test_create_dataframe_with_global_ctx(batch):
+    ctx = SessionContext.global_ctx()
+
+    df = ctx.create_dataframe([[batch]])
+
+    result = df.collect()[0].column(0)
+
+    assert result == pa.array([4, 5, 6])
diff --git a/src/context.rs b/src/context.rs
index 9ba87eb8a..0db0f4d7e 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -308,7 +308,7 @@ impl PySessionContext {
 
     #[classmethod]
     #[pyo3(signature = ())]
-    fn _global_ctx(_cls: &Bound<'_, PyType>) -> PyResult<Self> {
+    fn global_ctx(_cls: &Bound<'_, PyType>) -> PyResult<Self> {
         Ok(Self {
             ctx: get_global_ctx().clone(),
         })