
Commit a29f103

feat: reads using global ctx
1 parent 79c22d6 commit a29f103

File tree

7 files changed: +304 −2

  python/datafusion/__init__.py
  python/datafusion/io.py
  python/tests/test_context.py
  python/tests/test_io.py
  python/tests/test_wrapper_coverage.py
  src/context.rs
  src/utils.rs

python/datafusion/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -48,6 +48,8 @@

 from .dataframe import DataFrame

+from .io import read_parquet
+
 from .expr import (
     Expr,
     WindowFrame,
@@ -89,6 +91,7 @@
     "functions",
     "object_store",
     "substrait",
+    "read_parquet",
 ]
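With read_parquet re-exported here, the helper can be imported directly from the top-level datafusion package. A minimal sketch (the file path is a placeholder, not a file shipped with the repository):

    from datafusion import read_parquet

    df = read_parquet("data.parquet")  # placeholder path
    df.show()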

python/datafusion/io.py

Lines changed: 181 additions & 0 deletions
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""IO read functions using global context."""

import pathlib

from datafusion.dataframe import DataFrame
from datafusion.expr import Expr
import pyarrow
from ._internal import SessionContext as SessionContextInternal


def read_parquet(
    path: str | pathlib.Path,
    table_partition_cols: list[tuple[str, str]] | None = None,
    parquet_pruning: bool = True,
    file_extension: str = ".parquet",
    skip_metadata: bool = True,
    schema: pyarrow.Schema | None = None,
    file_sort_order: list[list[Expr]] | None = None,
) -> DataFrame:
    """Read a Parquet source into a :py:class:`~datafusion.dataframe.DataFrame`.

    Args:
        path: Path to the Parquet file.
        table_partition_cols: Partition columns.
        parquet_pruning: Whether the parquet reader should use the predicate
            to prune row groups.
        file_extension: File extension; only files with this extension are
            selected for data input.
        skip_metadata: Whether the parquet reader should skip any metadata
            that may be in the file schema. This can help avoid schema
            conflicts due to metadata.
        schema: An optional schema representing the parquet files. If None,
            the parquet reader will try to infer it based on data in the
            file.
        file_sort_order: Sort order for the file.

    Returns:
        DataFrame representation of the read Parquet files.
    """
    if table_partition_cols is None:
        table_partition_cols = []
    return DataFrame(
        SessionContextInternal._global_ctx().read_parquet(
            str(path),
            table_partition_cols,
            parquet_pruning,
            file_extension,
            skip_metadata,
            schema,
            file_sort_order,
        )
    )


def read_json(
    path: str | pathlib.Path,
    schema: pyarrow.Schema | None = None,
    schema_infer_max_records: int = 1000,
    file_extension: str = ".json",
    table_partition_cols: list[tuple[str, str]] | None = None,
    file_compression_type: str | None = None,
) -> DataFrame:
    """Read a line-delimited JSON data source.

    Args:
        path: Path to the JSON file.
        schema: The data source schema.
        schema_infer_max_records: Maximum number of rows to read from JSON
            files for schema inference if needed.
        file_extension: File extension; only files with this extension are
            selected for data input.
        table_partition_cols: Partition columns.
        file_compression_type: File compression type.

    Returns:
        DataFrame representation of the read JSON files.
    """
    if table_partition_cols is None:
        table_partition_cols = []
    return DataFrame(
        SessionContextInternal._global_ctx().read_json(
            str(path),
            schema,
            schema_infer_max_records,
            file_extension,
            table_partition_cols,
            file_compression_type,
        )
    )


def read_csv(
    path: str | pathlib.Path | list[str] | list[pathlib.Path],
    schema: pyarrow.Schema | None = None,
    has_header: bool = True,
    delimiter: str = ",",
    schema_infer_max_records: int = 1000,
    file_extension: str = ".csv",
    table_partition_cols: list[tuple[str, str]] | None = None,
    file_compression_type: str | None = None,
) -> DataFrame:
    """Read a CSV data source.

    Args:
        path: Path to the CSV file.
        schema: An optional schema representing the CSV files. If None, the
            CSV reader will try to infer it based on data in the file.
        has_header: Whether the CSV file has a header. If schema inference
            is run on a file with no headers, default column names are
            created.
        delimiter: An optional column delimiter.
        schema_infer_max_records: Maximum number of rows to read from CSV
            files for schema inference if needed.
        file_extension: File extension; only files with this extension are
            selected for data input.
        table_partition_cols: Partition columns.
        file_compression_type: File compression type.

    Returns:
        DataFrame representation of the read CSV files.
    """
    if table_partition_cols is None:
        table_partition_cols = []

    path = [str(p) for p in path] if isinstance(path, list) else str(path)

    return DataFrame(
        SessionContextInternal._global_ctx().read_csv(
            path,
            schema,
            has_header,
            delimiter,
            schema_infer_max_records,
            file_extension,
            table_partition_cols,
            file_compression_type,
        )
    )


def read_avro(
    path: str | pathlib.Path,
    schema: pyarrow.Schema | None = None,
    file_partition_cols: list[tuple[str, str]] | None = None,
    file_extension: str = ".avro",
) -> DataFrame:
    """Create a :py:class:`DataFrame` for reading an Avro data source.

    Args:
        path: Path to the Avro file.
        schema: The data source schema.
        file_partition_cols: Partition columns.
        file_extension: File extension to select.

    Returns:
        DataFrame representation of the read Avro file.
    """
    if file_partition_cols is None:
        file_partition_cols = []
    return DataFrame(
        SessionContextInternal._global_ctx().read_avro(
            str(path), schema, file_partition_cols, file_extension
        )
    )
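All four helpers above resolve their reads against the shared global context, so no SessionContext has to be constructed for simple one-off reads. A minimal usage sketch (file paths are placeholders, not files shipped with the repository):

    from datafusion.io import read_csv, read_parquet

    # Both calls go through SessionContextInternal._global_ctx(), i.e. the
    # same process-wide SessionContext.
    parquet_df = read_parquet("data.parquet")                       # placeholder path
    csv_df = read_csv(["part1.csv", "part2.csv"], has_header=True)  # placeholder paths

    parquet_df.show()
    csv_df.show()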

python/tests/test_context.py

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@
 import pyarrow.dataset as ds
 import pytest

+
 from datafusion import (
     DataFrame,
     RuntimeConfig,

python/tests/test_io.py

Lines changed: 97 additions & 0 deletions
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import pathlib

from datafusion import column
import pyarrow as pa


from datafusion.io import read_avro, read_csv, read_json, read_parquet


def test_read_json_global_ctx(ctx):
    path = os.path.dirname(os.path.abspath(__file__))

    # Default
    test_data_path = os.path.join(path, "data_test_context", "data.json")
    df = read_json(test_data_path)
    result = df.collect()

    assert result[0].column(0) == pa.array(["a", "b", "c"])
    assert result[0].column(1) == pa.array([1, 2, 3])

    # Schema
    schema = pa.schema(
        [
            pa.field("A", pa.string(), nullable=True),
        ]
    )
    df = read_json(test_data_path, schema=schema)
    result = df.collect()

    assert result[0].column(0) == pa.array(["a", "b", "c"])
    assert result[0].schema == schema

    # File extension
    test_data_path = os.path.join(path, "data_test_context", "data.json")
    df = read_json(test_data_path, file_extension=".json")
    result = df.collect()

    assert result[0].column(0) == pa.array(["a", "b", "c"])
    assert result[0].column(1) == pa.array([1, 2, 3])


def test_read_parquet_global():
    parquet_df = read_parquet(path="parquet/data/alltypes_plain.parquet")
    parquet_df.show()
    assert parquet_df is not None

    path = pathlib.Path.cwd() / "parquet/data/alltypes_plain.parquet"
    parquet_df = read_parquet(path=path)
    assert parquet_df is not None


def test_read_csv():
    csv_df = read_csv(path="testing/data/csv/aggregate_test_100.csv")
    csv_df.select(column("c1")).show()


def test_read_csv_list():
    csv_df = read_csv(path=["testing/data/csv/aggregate_test_100.csv"])
    expected = csv_df.count() * 2

    double_csv_df = read_csv(
        path=[
            "testing/data/csv/aggregate_test_100.csv",
            "testing/data/csv/aggregate_test_100.csv",
        ]
    )
    actual = double_csv_df.count()

    double_csv_df.select(column("c1")).show()
    assert actual == expected


def test_read_avro():
    avro_df = read_avro(path="testing/data/avro/alltypes_plain.avro")
    avro_df.show()
    assert avro_df is not None

    path = pathlib.Path.cwd() / "testing/data/avro/alltypes_plain.avro"
    avro_df = read_avro(path=path)
    assert avro_df is not None
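test_read_json_global_ctx accepts a ctx fixture even though its body only exercises the global-context helpers; the fixture is presumably provided by the suite's conftest.py. A minimal sketch of such a fixture, assuming it simply returns a fresh SessionContext:

    import pytest

    from datafusion import SessionContext


    @pytest.fixture
    def ctx():
        # Hypothetical stand-in; the real fixture lives in the suite's conftest.py.
        return SessionContext()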

python/tests/test_wrapper_coverage.py

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,8 @@ def missing_exports(internal_obj, wrapped_obj) -> None:
         return

     for attr in dir(internal_obj):
+        if attr in ["_global_ctx"]:
+            continue
         assert attr in dir(wrapped_obj)

         internal_attr = getattr(internal_obj, attr)

src/context.rs

Lines changed: 12 additions & 2 deletions
@@ -43,7 +43,7 @@ use crate::store::StorageContexts;
 use crate::udaf::PyAggregateUDF;
 use crate::udf::PyScalarUDF;
 use crate::udwf::PyWindowUDF;
-use crate::utils::{get_tokio_runtime, wait_for_future};
+use crate::utils::{get_global_ctx, get_tokio_runtime, wait_for_future};
 use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
 use datafusion::arrow::pyarrow::PyArrowType;
 use datafusion::arrow::record_batch::RecordBatch;
@@ -68,7 +68,7 @@ use datafusion::prelude::{
     AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
 };
 use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
-use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple};
+use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
 use tokio::task::JoinHandle;

 /// Configuration options for a SessionContext
@@ -299,6 +299,16 @@ impl PySessionContext {
         })
     }

+    #[classmethod]
+    #[pyo3(signature = ())]
+    fn _global_ctx(
+        _cls: &Bound<'_, PyType>,
+    ) -> PyResult<Self> {
+        Ok(Self {
+            ctx: get_global_ctx().clone(),
+        })
+    }
+
     /// Register an object store with the given name
     #[pyo3(signature = (scheme, store, host=None))]
     pub fn register_object_store(
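The new _global_ctx classmethod is the hook that python/datafusion/io.py calls through the internal bindings: each call wraps the process-wide SessionContext created in src/utils.rs in a new PySessionContext. A rough Python-side illustration of that call path (arguments mirror the defaults used by read_parquet in io.py; the path is a placeholder):

    from datafusion._internal import SessionContext as SessionContextInternal

    internal_ctx = SessionContextInternal._global_ctx()
    raw_df = internal_ctx.read_parquet(
        "data.parquet",  # placeholder path
        [],              # table_partition_cols
        True,            # parquet_pruning
        ".parquet",      # file_extension
        True,            # skip_metadata
        None,            # schema
        None,            # file_sort_order
    )
    # io.py wraps raw_df in datafusion.dataframe.DataFrame before returning it.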

src/utils.rs

Lines changed: 8 additions & 0 deletions
@@ -22,6 +22,7 @@ use pyo3::prelude::*;
 use std::future::Future;
 use std::sync::OnceLock;
 use tokio::runtime::Runtime;
+use datafusion::execution::context::SessionContext;

 /// Utility to get the Tokio Runtime from Python
 #[inline]
@@ -35,6 +36,13 @@ pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime {
     RUNTIME.get_or_init(|| TokioRuntime(tokio::runtime::Runtime::new().unwrap()))
 }

+/// Utility to get the global DataFusion SessionContext
+#[inline]
+pub(crate) fn get_global_ctx() -> &'static SessionContext {
+    static CTX: OnceLock<SessionContext> = OnceLock::new();
+    CTX.get_or_init(SessionContext::new)
+}
+
 /// Utility to collect rust futures with GIL released
 pub fn wait_for_future<F>(py: Python, f: F) -> F::Output
 where
