
Commit 404e542

ion-elgreco authored and timsaucer committed
feat: reads using global ctx
1 parent 3584bec commit 404e542

File tree

6 files changed: +303 -2 lines changed
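
This commit adds module-level read helpers (read_avro, read_csv, read_json, read_parquet) that run against a process-wide global SessionContext, so data can be loaded into a DataFrame without constructing a context first. A minimal usage sketch of the intent (the file path below is hypothetical):

# Previously an explicit context was required:
#   ctx = SessionContext()
#   df = ctx.read_parquet("example.parquet")
# With this commit the helper goes through the library's global context instead.
from datafusion import read_parquet

df = read_parquet("example.parquet")  # hypothetical path
df.show()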

python/datafusion/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -45,6 +45,7 @@
     Expr,
     WindowFrame,
 )
+from .io import read_avro, read_csv, read_json, read_parquet
 from .plan import ExecutionPlan, LogicalPlan
 from .record_batch import RecordBatch, RecordBatchStream
 from .udf import Accumulator, AggregateUDF, ScalarUDF, WindowUDF
@@ -81,6 +82,10 @@
     "functions",
     "object_store",
     "substrait",
+    "read_parquet",
+    "read_avro",
+    "read_csv",
+    "read_json",
 ]

python/datafusion/io.py

Lines changed: 183 additions & 0 deletions
@@ -0,0 +1,183 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""IO read functions using global context."""

import pathlib

import pyarrow

from datafusion.dataframe import DataFrame
from datafusion.expr import Expr

from ._internal import SessionContext as SessionContextInternal


def read_parquet(
    path: str | pathlib.Path,
    table_partition_cols: list[tuple[str, str]] | None = None,
    parquet_pruning: bool = True,
    file_extension: str = ".parquet",
    skip_metadata: bool = True,
    schema: pyarrow.Schema | None = None,
    file_sort_order: list[list[Expr]] | None = None,
) -> DataFrame:
    """Read a Parquet source into a :py:class:`~datafusion.dataframe.DataFrame`.

    Args:
        path: Path to the Parquet file.
        table_partition_cols: Partition columns.
        parquet_pruning: Whether the parquet reader should use the predicate
            to prune row groups.
        file_extension: File extension; only files with this extension are
            selected for data input.
        skip_metadata: Whether the parquet reader should skip any metadata
            that may be in the file schema. This can help avoid schema
            conflicts due to metadata.
        schema: An optional schema representing the parquet files. If None,
            the parquet reader will try to infer it based on data in the
            file.
        file_sort_order: Sort order for the file.

    Returns:
        DataFrame representation of the read Parquet files.
    """
    if table_partition_cols is None:
        table_partition_cols = []
    return DataFrame(
        SessionContextInternal._global_ctx().read_parquet(
            str(path),
            table_partition_cols,
            parquet_pruning,
            file_extension,
            skip_metadata,
            schema,
            file_sort_order,
        )
    )


def read_json(
    path: str | pathlib.Path,
    schema: pyarrow.Schema | None = None,
    schema_infer_max_records: int = 1000,
    file_extension: str = ".json",
    table_partition_cols: list[tuple[str, str]] | None = None,
    file_compression_type: str | None = None,
) -> DataFrame:
    """Read a line-delimited JSON data source.

    Args:
        path: Path to the JSON file.
        schema: The data source schema.
        schema_infer_max_records: Maximum number of rows to read from JSON
            files for schema inference if needed.
        file_extension: File extension; only files with this extension are
            selected for data input.
        table_partition_cols: Partition columns.
        file_compression_type: File compression type.

    Returns:
        DataFrame representation of the read JSON files.
    """
    if table_partition_cols is None:
        table_partition_cols = []
    return DataFrame(
        SessionContextInternal._global_ctx().read_json(
            str(path),
            schema,
            schema_infer_max_records,
            file_extension,
            table_partition_cols,
            file_compression_type,
        )
    )


def read_csv(
    path: str | pathlib.Path | list[str] | list[pathlib.Path],
    schema: pyarrow.Schema | None = None,
    has_header: bool = True,
    delimiter: str = ",",
    schema_infer_max_records: int = 1000,
    file_extension: str = ".csv",
    table_partition_cols: list[tuple[str, str]] | None = None,
    file_compression_type: str | None = None,
) -> DataFrame:
    """Read a CSV data source.

    Args:
        path: Path to the CSV file.
        schema: An optional schema representing the CSV files. If None, the
            CSV reader will try to infer it based on data in the file.
        has_header: Whether the CSV file has a header. If schema inference
            is run on a file with no headers, default column names are
            created.
        delimiter: An optional column delimiter.
        schema_infer_max_records: Maximum number of rows to read from CSV
            files for schema inference if needed.
        file_extension: File extension; only files with this extension are
            selected for data input.
        table_partition_cols: Partition columns.
        file_compression_type: File compression type.

    Returns:
        DataFrame representation of the read CSV files.
    """
    if table_partition_cols is None:
        table_partition_cols = []

    path = [str(p) for p in path] if isinstance(path, list) else str(path)

    return DataFrame(
        SessionContextInternal._global_ctx().read_csv(
            path,
            schema,
            has_header,
            delimiter,
            schema_infer_max_records,
            file_extension,
            table_partition_cols,
            file_compression_type,
        )
    )


def read_avro(
    path: str | pathlib.Path,
    schema: pyarrow.Schema | None = None,
    file_partition_cols: list[tuple[str, str]] | None = None,
    file_extension: str = ".avro",
) -> DataFrame:
    """Create a :py:class:`DataFrame` for reading an Avro data source.

    Args:
        path: Path to the Avro file.
        schema: The data source schema.
        file_partition_cols: Partition columns.
        file_extension: File extension to select.

    Returns:
        DataFrame representation of the read Avro file.
    """
    if file_partition_cols is None:
        file_partition_cols = []
    return DataFrame(
        SessionContextInternal._global_ctx().read_avro(
            str(path), schema, file_partition_cols, file_extension
        )
    )
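
Each helper above forwards its arguments to the matching read method on the internal global context and wraps the result in a DataFrame. A short, hedged sketch of calling the readers defined in this file (the paths are hypothetical; the keyword arguments are the parameters documented in the signatures above):

from datafusion.io import read_csv, read_json

# read_csv accepts a single path or a list of paths; list entries are stringified.
df = read_csv(
    ["part1.csv", "part2.csv"],  # hypothetical paths
    has_header=True,
    delimiter=",",
)

# Line-delimited JSON, with schema inference capped at the default 1000 records.
events = read_json("events.json", schema_infer_max_records=1000)  # hypothetical path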

python/tests/test_io.py

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import pathlib

import pyarrow as pa
from datafusion import column
from datafusion.io import read_avro, read_csv, read_json, read_parquet


def test_read_json_global_ctx(ctx):
    path = os.path.dirname(os.path.abspath(__file__))

    # Default
    test_data_path = os.path.join(path, "data_test_context", "data.json")
    df = read_json(test_data_path)
    result = df.collect()

    assert result[0].column(0) == pa.array(["a", "b", "c"])
    assert result[0].column(1) == pa.array([1, 2, 3])

    # Schema
    schema = pa.schema(
        [
            pa.field("A", pa.string(), nullable=True),
        ]
    )
    df = read_json(test_data_path, schema=schema)
    result = df.collect()

    assert result[0].column(0) == pa.array(["a", "b", "c"])
    assert result[0].schema == schema

    # File extension
    test_data_path = os.path.join(path, "data_test_context", "data.json")
    df = read_json(test_data_path, file_extension=".json")
    result = df.collect()

    assert result[0].column(0) == pa.array(["a", "b", "c"])
    assert result[0].column(1) == pa.array([1, 2, 3])


def test_read_parquet_global():
    parquet_df = read_parquet(path="parquet/data/alltypes_plain.parquet")
    parquet_df.show()
    assert parquet_df is not None

    path = pathlib.Path.cwd() / "parquet/data/alltypes_plain.parquet"
    parquet_df = read_parquet(path=path)
    assert parquet_df is not None


def test_read_csv():
    csv_df = read_csv(path="testing/data/csv/aggregate_test_100.csv")
    csv_df.select(column("c1")).show()


def test_read_csv_list():
    csv_df = read_csv(path=["testing/data/csv/aggregate_test_100.csv"])
    expected = csv_df.count() * 2

    double_csv_df = read_csv(
        path=[
            "testing/data/csv/aggregate_test_100.csv",
            "testing/data/csv/aggregate_test_100.csv",
        ]
    )
    actual = double_csv_df.count()

    double_csv_df.select(column("c1")).show()
    assert actual == expected


def test_read_avro():
    avro_df = read_avro(path="testing/data/avro/alltypes_plain.avro")
    avro_df.show()
    assert avro_df is not None

    path = pathlib.Path.cwd() / "testing/data/avro/alltypes_plain.avro"
    avro_df = read_avro(path=path)
    assert avro_df is not None

python/tests/test_wrapper_coverage.py

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,8 @@ def missing_exports(internal_obj, wrapped_obj) -> None:
         return

     for attr in dir(internal_obj):
+        if attr in ["_global_ctx"]:
+            continue
         assert attr in dir(wrapped_obj)

         internal_attr = getattr(internal_obj, attr)

src/context.rs

Lines changed: 10 additions & 2 deletions
@@ -44,7 +44,7 @@ use crate::store::StorageContexts;
 use crate::udaf::PyAggregateUDF;
 use crate::udf::PyScalarUDF;
 use crate::udwf::PyWindowUDF;
-use crate::utils::{get_tokio_runtime, validate_pycapsule, wait_for_future};
+use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_for_future};
 use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
 use datafusion::arrow::pyarrow::PyArrowType;
 use datafusion::arrow::record_batch::RecordBatch;
@@ -69,7 +69,7 @@ use datafusion::prelude::{
     AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
 };
 use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
-use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple};
+use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
 use tokio::task::JoinHandle;

 /// Configuration options for a SessionContext
@@ -306,6 +306,14 @@ impl PySessionContext {
         })
     }

+    #[classmethod]
+    #[pyo3(signature = ())]
+    fn _global_ctx(_cls: &Bound<'_, PyType>) -> PyResult<Self> {
+        Ok(Self {
+            ctx: get_global_ctx().clone(),
+        })
+    }
+
     /// Register an object store with the given name
     #[pyo3(signature = (scheme, store, host=None))]
     pub fn register_object_store(
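
The new _global_ctx classmethod returns a PySessionContext wrapping a clone of the static context from get_global_ctx; since a cloned DataFusion SessionContext shares its underlying state, every call is expected to observe the same registered tables and configuration. An illustrative Python-side sketch of that expectation (it relies on the private _internal module already used by io.py above):

from datafusion._internal import SessionContext as SessionContextInternal

ctx_a = SessionContextInternal._global_ctx()
ctx_b = SessionContextInternal._global_ctx()
# Both handles wrap the same process-wide context, so state registered
# through one (tables, object stores, ...) should be visible through the other.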

src/utils.rs

Lines changed: 8 additions & 0 deletions
@@ -17,6 +17,7 @@

 use crate::errors::{PyDataFusionError, PyDataFusionResult};
 use crate::TokioRuntime;
+use datafusion::execution::context::SessionContext;
 use datafusion::logical_expr::Volatility;
 use pyo3::exceptions::PyValueError;
 use pyo3::prelude::*;
@@ -37,6 +38,13 @@ pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime {
     RUNTIME.get_or_init(|| TokioRuntime(tokio::runtime::Runtime::new().unwrap()))
 }

+/// Utility to get the global DataFusion ctx
+#[inline]
+pub(crate) fn get_global_ctx() -> &'static SessionContext {
+    static CTX: OnceLock<SessionContext> = OnceLock::new();
+    CTX.get_or_init(|| SessionContext::new())
+}
+
 /// Utility to collect rust futures with GIL released
 pub fn wait_for_future<F>(py: Python, f: F) -> F::Output
 where
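
get_global_ctx lazily builds a single SessionContext inside a OnceLock on first use and hands every caller the same 'static reference, mirroring how get_tokio_runtime caches the Tokio runtime above. A rough Python analogue of that initialize-once pattern, shown only to illustrate the idea (not part of the bindings):

from functools import lru_cache

@lru_cache(maxsize=None)  # the body runs once; later calls reuse the cached object
def get_global_ctx():
    # Stand-in for constructing the real SessionContext on first use.
    return object()

assert get_global_ctx() is get_global_ctx()  # every caller gets the same instance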
