diff --git a/Cargo.lock b/Cargo.lock index 5a74a4839..c6590fd21 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -79,7 +79,7 @@ checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", "const-random", - "getrandom", + "getrandom 0.2.15", "once_cell", "version_check", "zerocopy", @@ -449,9 +449,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.85" +version = "0.1.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f934833b4b7233644e5848f235df3f57ed8c80f1528a26c3dfa13d2147fa056" +checksum = "644dd749086bf3771a2fbc5f256fdb982d53f011c7d5d560304eafeecebce79d" dependencies = [ "proc-macro2", "quote", @@ -576,9 +576,9 @@ dependencies = [ [[package]] name = "brotli-decompressor" -version = "4.0.1" +version = "4.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a45bd2e4095a8b518033b128020dd4a55aab1c0a381ba4404a472630f4bc362" +checksum = "74fa05ad7d803d413eb8380983b092cbbaf9a85f151b871360e7b00cd7060b37" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -586,9 +586,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.16.0" +version = "3.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" +checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" [[package]] name = "byteorder" @@ -635,9 +635,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.10" +version = "1.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13208fcbb66eaeffe09b99fffbe1af420f00a7b35aa99ad683dfc1aa76145229" +checksum = "e4730490333d58093109dc02c23174c3f4d490998c3fed3cc8e82d57afedb9cf" dependencies = [ "jobserver", "libc", @@ -692,9 +692,9 @@ dependencies = [ [[package]] name = "cmake" -version = "0.1.52" +version = "0.1.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c682c223677e0e5b6b7f63a64b9351844c3f1b1678a68b7ee617e30fb082620e" +checksum = "e24a03c8b52922d68a1589ad61032f2c1aa5a8158d2aa0d93c6e9534944bbad6" dependencies = [ "cc", ] @@ -725,7 +725,7 @@ version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" dependencies = [ - "getrandom", + "getrandom 0.2.15", "once_cell", "tiny-keccak", ] @@ -784,9 +784,9 @@ checksum = "69f3b219d28b6e3b4ac87bc1fc522e0803ab22e055da177bff0068c4150c61a6" [[package]] name = "cpufeatures" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" dependencies = [ "libc", ] @@ -817,9 +817,9 @@ checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crunchy" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" +checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" [[package]] name = "crypto-common" @@ -961,7 +961,6 @@ dependencies = [ "object_store", "parquet", "paste", - "pyo3", "recursive", "sqlparser", "tokio", @@ -1411,9 +1410,9 @@ dependencies = [ [[package]] name = "dyn-clone" -version = "1.0.17" +version = "1.0.18" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125" +checksum = "feeef44e73baff3a26d371801df019877a9866a8c493d315ab00177843314f35" [[package]] name = "either" @@ -1607,10 +1606,22 @@ dependencies = [ "cfg-if", "js-sys", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "wasm-bindgen", ] +[[package]] +name = "getrandom" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.13.3+wasi-0.2.2", + "windows-targets", +] + [[package]] name = "gimli" version = "0.31.1" @@ -1722,9 +1733,9 @@ dependencies = [ [[package]] name = "httparse" -version = "1.9.5" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d71d3574edd2771538b901e6549113b4006ece66150fb69c0fb6d9a2adae946" +checksum = "f2d708df4e7140240a16cd6ab0ab65c972d7433ab77819ea693fde9c43811e2a" [[package]] name = "humantime" @@ -1734,9 +1745,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "1.5.2" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "256fb8d4bd6413123cc9d91832d78325c48ff41677595be797d90f42969beae0" +checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80" dependencies = [ "bytes", "futures-channel", @@ -1953,9 +1964,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.7.0" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" +checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652" dependencies = [ "equivalent", "hashbrown 0.15.2", @@ -1975,9 +1986,9 @@ checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02" [[package]] name = "ipnet" -version = "2.10.1" +version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708" +checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" [[package]] name = "itertools" @@ -2243,7 +2254,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" dependencies = [ "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "windows-sys 0.52.0", ] @@ -2377,9 +2388,9 @@ checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" [[package]] name = "openssl-probe" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" [[package]] name = "ordered-float" @@ -2661,9 +2672,9 @@ dependencies = [ [[package]] name = "protobuf-src" -version = "2.1.0+27.1" +version = "2.1.1+27.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7edafa3bcc668fa93efafcbdf58d7821bbda0f4b458ac7fae3d57ec0fec8167" +checksum = "6217c3504da19b85a3a4b2e9a5183d635822d83507ba0986624b5c05b83bfc40" dependencies = [ "cmake", ] @@ -2794,7 +2805,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d" 
dependencies = [ "bytes", - "getrandom", + "getrandom 0.2.15", "rand", "ring", "rustc-hash", @@ -2857,7 +2868,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom", + "getrandom 0.2.15", ] [[package]] @@ -2926,9 +2937,9 @@ checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "regress" -version = "0.10.2" +version = "0.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f56e622c2378013c6c61e2bd776604c46dc1087b2dc5293275a0c20a44f0771" +checksum = "78ef7fa9ed0256d64a688a3747d0fef7a88851c18a5e1d57f115f38ec2e09366" dependencies = [ "hashbrown 0.15.2", "memchr", @@ -2997,7 +3008,7 @@ checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" dependencies = [ "cc", "cfg-if", - "getrandom", + "getrandom 0.2.15", "libc", "spin", "untrusted", @@ -3033,9 +3044,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.43" +version = "0.38.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a78891ee6bf2340288408954ac787aa063d8e8817e9f53abb37c695c6d834ef6" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" dependencies = [ "bitflags 2.8.0", "errno", @@ -3046,9 +3057,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.21" +version = "0.23.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f287924602bf649d949c63dc8ac8b235fa5387d394020705b80c4eb597ce5b8" +checksum = "9fb9263ab4eb695e42321db096e3b8fbd715a59b154d5c88d82db2175b681ba7" dependencies = [ "once_cell", "ring", @@ -3081,9 +3092,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.10.1" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2bf47e6ff922db3825eb750c4e2ff784c6ff8fb9e13046ef6a1d1c5401b0b37" +checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c" dependencies = [ "web-time", ] @@ -3107,9 +3118,9 @@ checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4" [[package]] name = "ryu" -version = "1.0.18" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" [[package]] name = "same-file" @@ -3184,9 +3195,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.24" +version = "1.0.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3cb6eb87a131f756572d7fb904f6e7b68633f09cca868c5df1c4b8d1a694bbba" +checksum = "f79dfe2d285b0488816f30e700a7438c5a73d816b5b7d3ac72fbc48b0d185e03" dependencies = [ "serde", ] @@ -3239,9 +3250,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.136" +version = "1.0.138" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "336a0c23cf42a38d9eaa7cd22c7040d04e1228a19a933890805ffd00a16437d2" +checksum = "d434192e7da787e94a6ea7e9670b26a036d0ca41e0b7efb2676dd32bae872949" dependencies = [ "itoa", "memchr", @@ -3514,13 +3525,13 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" [[package]] name = "tempfile" -version = "3.15.0" +version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"9a8a559c81686f576e8cd0290cd2a24a2a9ad80c98b3478856500fcbd7acd704" +checksum = "38c246215d7d24f48ae091a2902398798e05d978b24315d6efbc00ede9a8bb91" dependencies = [ "cfg-if", "fastrand", - "getrandom", + "getrandom 0.3.1", "once_cell", "rustix", "windows-sys 0.59.0", @@ -3831,9 +3842,9 @@ dependencies = [ [[package]] name = "unicode-ident" -version = "1.0.14" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" +checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034" [[package]] name = "unicode-segmentation" @@ -3890,11 +3901,11 @@ checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" [[package]] name = "uuid" -version = "1.12.0" +version = "1.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "744018581f9a3454a9e15beb8a33b017183f1e7c0cd170232a2d1453b23a51c4" +checksum = "b3758f5e68192bb96cc8f9b7e2c2cfdabb435499a28499a42f8f984092adad4b" dependencies = [ - "getrandom", + "getrandom 0.2.15", "serde", ] @@ -3929,6 +3940,15 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasi" +version = "0.13.3+wasi-0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +dependencies = [ + "wit-bindgen-rt", +] + [[package]] name = "wasm-bindgen" version = "0.2.100" @@ -4185,6 +4205,15 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "wit-bindgen-rt" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" +dependencies = [ + "bitflags 2.8.0", +] + [[package]] name = "write16" version = "1.0.0" diff --git a/Cargo.toml b/Cargo.toml index 10cffccb1..003ba36e5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,7 +38,7 @@ tokio = { version = "1.42", features = ["macros", "rt", "rt-multi-thread", "sync pyo3 = { version = "0.22", features = ["extension-module", "abi3", "abi3-py38"] } pyo3-async-runtimes = { version = "0.22", features = ["tokio-runtime"]} arrow = { version = "53", features = ["pyarrow"] } -datafusion = { version = "44.0.0", features = ["pyarrow", "avro", "unicode_expressions"] } +datafusion = { version = "44.0.0", features = ["avro", "unicode_expressions"] } datafusion-substrait = { version = "44.0.0", optional = true } datafusion-proto = { version = "44.0.0" } datafusion-ffi = { version = "44.0.0" } diff --git a/docs/source/contributor-guide/introduction.rst b/docs/source/contributor-guide/introduction.rst index fb98cfd1d..25f2c21a4 100644 --- a/docs/source/contributor-guide/introduction.rst +++ b/docs/source/contributor-guide/introduction.rst @@ -95,3 +95,56 @@ To update dependencies, run .. code-block:: shell uv sync --dev --no-install-package datafusion + +Improving Build Speed +--------------------- + +The `pyo3 `_ dependency of this project contains a ``build.rs`` file which +can cause it to rebuild frequently. You can prevent this from happening by defining a ``PYO3_CONFIG_FILE`` +environment variable that points to a file with your build configuration. 
Whenever your build configuration +changes, such as during major version updates, you will need to regenerate this file. This variable +should point to a fully resolved path on your build machine. + +To generate this file, use the following command: + +.. code-block:: shell + +    PYO3_PRINT_CONFIG=1 cargo build + +This will generate some output that looks like the following. You will want to copy these contents into +a file. If you place this file in your project directory with the filename ``.pyo3_build_config``, it will +be ignored by ``git``. + +.. code-block:: + +    implementation=CPython +    version=3.8 +    shared=true +    abi3=true +    lib_name=python3.12 +    lib_dir=/opt/homebrew/opt/python@3.12/Frameworks/Python.framework/Versions/3.12/lib +    executable=/Users/myusername/src/datafusion-python/.venv/bin/python +    pointer_width=64 +    build_flags= +    suppress_build_script_link_lines=false + +Add the environment variable to your system. + +.. code-block:: shell + +    export PYO3_CONFIG_FILE="/Users/myusername/src/datafusion-python/.pyo3_build_config" + +If you are on a Mac and use VS Code as your IDE, you will want to add these variables +to your settings. You can find the appropriate Rust flags by looking in the +``.cargo/config.toml`` file. + +.. code-block:: + +    "rust-analyzer.cargo.extraEnv": { +        "RUSTFLAGS": "-C link-arg=-undefined -C link-arg=dynamic_lookup", +        "PYO3_CONFIG_FILE": "/Users/myusername/src/datafusion-python/.pyo3_build_config" +    }, +    "rust-analyzer.runnables.extraEnv": { +        "RUSTFLAGS": "-C link-arg=-undefined -C link-arg=dynamic_lookup", +        "PYO3_CONFIG_FILE": "/Users/myusername/src/datafusion-python/.pyo3_build_config" +    } diff --git a/python/tests/test_indexing.py b/python/tests/test_indexing.py index 5b0d08610..327decd2f 100644 --- a/python/tests/test_indexing.py +++ b/python/tests/test_indexing.py @@ -43,7 +43,8 @@ def test_err(df): with pytest.raises(Exception) as e_info: df["c"] - assert "Schema error: No field named c." in e_info.value.args[0] + for e in ["SchemaError", "FieldNotFound", 'name: "c"']: + assert e in e_info.value.args[0] with pytest.raises(Exception) as e_info: df[1] diff --git a/src/catalog.rs b/src/catalog.rs index 1ce66a4dc..1e189a5aa 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use pyo3::exceptions::PyKeyError; use pyo3::prelude::*; -use crate::errors::DataFusionError; +use crate::errors::{PyDataFusionError, PyDataFusionResult}; use crate::utils::wait_for_future; use datafusion::{ arrow::pyarrow::ToPyArrow, @@ -96,11 +96,13 @@ impl PyDatabase { self.database.table_names().into_iter().collect() } - fn table(&self, name: &str, py: Python) -> PyResult<PyTable> { + fn table(&self, name: &str, py: Python) -> PyDataFusionResult<PyTable> { if let Some(table) = wait_for_future(py, self.database.table(name))?
{ Ok(PyTable::new(table)) } else { - Err(DataFusionError::Common(format!("Table not found: {name}")).into()) + Err(PyDataFusionError::Common(format!( + "Table not found: {name}" + ))) } } diff --git a/src/common/data_type.rs b/src/common/data_type.rs index 7f9c75bfd..f5f8a6b06 100644 --- a/src/common/data_type.rs +++ b/src/common/data_type.rs @@ -23,6 +23,20 @@ use pyo3::{exceptions::PyValueError, prelude::*}; use crate::errors::py_datafusion_err; +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd)] +pub struct PyScalarValue(pub ScalarValue); + +impl From for PyScalarValue { + fn from(value: ScalarValue) -> Self { + Self(value) + } +} +impl From for ScalarValue { + fn from(value: PyScalarValue) -> Self { + value.0 + } +} + #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] #[pyclass(eq, eq_int, name = "RexType", module = "datafusion.common")] pub enum RexType { diff --git a/src/config.rs b/src/config.rs index 3f2a05580..cc725b9a3 100644 --- a/src/config.rs +++ b/src/config.rs @@ -21,6 +21,8 @@ use pyo3::types::*; use datafusion::common::ScalarValue; use datafusion::config::ConfigOptions; +use crate::errors::PyDataFusionResult; + #[pyclass(name = "Config", module = "datafusion", subclass)] #[derive(Clone)] pub(crate) struct PyConfig { @@ -38,7 +40,7 @@ impl PyConfig { /// Get configurations from environment variables #[staticmethod] - pub fn from_env() -> PyResult { + pub fn from_env() -> PyDataFusionResult { Ok(Self { config: ConfigOptions::from_env()?, }) @@ -56,11 +58,10 @@ impl PyConfig { } /// Set a configuration option - pub fn set(&mut self, key: &str, value: PyObject, py: Python) -> PyResult<()> { + pub fn set(&mut self, key: &str, value: PyObject, py: Python) -> PyDataFusionResult<()> { let scalar_value = py_obj_to_scalar_value(py, value); - self.config - .set(key, scalar_value.to_string().as_str()) - .map_err(|e| e.into()) + self.config.set(key, scalar_value.to_string().as_str())?; + Ok(()) } /// Get all configuration options diff --git a/src/context.rs b/src/context.rs index bab7fd42a..f53b15576 100644 --- a/src/context.rs +++ b/src/context.rs @@ -28,16 +28,17 @@ use object_store::ObjectStore; use url::Url; use uuid::Uuid; -use pyo3::exceptions::{PyKeyError, PyNotImplementedError, PyTypeError, PyValueError}; +use pyo3::exceptions::{PyKeyError, PyValueError}; use pyo3::prelude::*; use crate::catalog::{PyCatalog, PyTable}; use crate::dataframe::PyDataFrame; use crate::dataset::Dataset; -use crate::errors::{py_datafusion_err, DataFusionError}; +use crate::errors::{py_datafusion_err, PyDataFusionResult}; use crate::expr::sort_expr::PySortExpr; use crate::physical_plan::PyExecutionPlan; use crate::record_batch::PyRecordBatchStream; +use crate::sql::exceptions::py_value_err; use crate::sql::logical::PyLogicalPlan; use crate::store::StorageContexts; use crate::udaf::PyAggregateUDF; @@ -277,7 +278,7 @@ impl PySessionContext { pub fn new( config: Option, runtime: Option, - ) -> PyResult { + ) -> PyDataFusionResult { let config = if let Some(c) = config { c.config } else { @@ -348,7 +349,7 @@ impl PySessionContext { schema: Option>, file_sort_order: Option>>, py: Python, - ) -> PyResult<()> { + ) -> PyDataFusionResult<()> { let options = ListingOptions::new(Arc::new(ParquetFormat::new())) .with_file_extension(file_extension) .with_table_partition_cols(convert_table_partition_cols(table_partition_cols)?) 
@@ -365,7 +366,7 @@ impl PySessionContext { None => { let state = self.ctx.state(); let schema = options.infer_schema(&state, &table_path); - wait_for_future(py, schema).map_err(DataFusionError::from)? + wait_for_future(py, schema)? } }; let config = ListingTableConfig::new(table_path) @@ -382,9 +383,9 @@ impl PySessionContext { } /// Returns a PyDataFrame whose plan corresponds to the SQL statement. - pub fn sql(&mut self, query: &str, py: Python) -> PyResult { + pub fn sql(&mut self, query: &str, py: Python) -> PyDataFusionResult { let result = self.ctx.sql(query); - let df = wait_for_future(py, result).map_err(DataFusionError::from)?; + let df = wait_for_future(py, result)?; Ok(PyDataFrame::new(df)) } @@ -394,14 +395,14 @@ impl PySessionContext { query: &str, options: Option, py: Python, - ) -> PyResult { + ) -> PyDataFusionResult { let options = if let Some(options) = options { options.options } else { SQLOptions::new() }; let result = self.ctx.sql_with_options(query, options); - let df = wait_for_future(py, result).map_err(DataFusionError::from)?; + let df = wait_for_future(py, result)?; Ok(PyDataFrame::new(df)) } @@ -412,14 +413,14 @@ impl PySessionContext { name: Option<&str>, schema: Option>, py: Python, - ) -> PyResult { + ) -> PyDataFusionResult { let schema = if let Some(schema) = schema { SchemaRef::from(schema.0) } else { partitions.0[0][0].schema() }; - let table = MemTable::try_new(schema, partitions.0).map_err(DataFusionError::from)?; + let table = MemTable::try_new(schema, partitions.0)?; // generate a random (unique) name for this table if none is provided // table name cannot start with numeric digit @@ -433,11 +434,9 @@ impl PySessionContext { } }; - self.ctx - .register_table(&*table_name, Arc::new(table)) - .map_err(DataFusionError::from)?; + self.ctx.register_table(&*table_name, Arc::new(table))?; - let table = wait_for_future(py, self._table(&table_name)).map_err(DataFusionError::from)?; + let table = wait_for_future(py, self._table(&table_name))?; let df = PyDataFrame::new(table); Ok(df) @@ -495,15 +494,14 @@ impl PySessionContext { data: Bound<'_, PyAny>, name: Option<&str>, py: Python, - ) -> PyResult { + ) -> PyDataFusionResult { let (schema, batches) = if let Ok(stream_reader) = ArrowArrayStreamReader::from_pyarrow_bound(&data) { // Works for any object that implements __arrow_c_stream__ in pycapsule. 
let schema = stream_reader.schema().as_ref().to_owned(); let batches = stream_reader - .collect::<Result<Vec<RecordBatch>, arrow::error::ArrowError>>() - .map_err(DataFusionError::from)?; + .collect::<Result<Vec<RecordBatch>, arrow::error::ArrowError>>()?; (schema, batches) } else if let Ok(array) = RecordBatch::from_pyarrow_bound(&data) { @@ -512,8 +510,8 @@ impl PySessionContext { (array.schema().as_ref().to_owned(), vec![array]) } else { - return Err(PyTypeError::new_err( - "Expected either a Arrow Array or Arrow Stream in from_arrow().", + return Err(crate::errors::PyDataFusionError::Common( + "Expected either an Arrow Array or an Arrow Stream in from_arrow().".to_string(), )); }; @@ -559,17 +557,13 @@ impl PySessionContext { Ok(df) } - pub fn register_table(&mut self, name: &str, table: &PyTable) -> PyResult<()> { - self.ctx - .register_table(name, table.table()) - .map_err(DataFusionError::from)?; + pub fn register_table(&mut self, name: &str, table: &PyTable) -> PyDataFusionResult<()> { + self.ctx.register_table(name, table.table())?; Ok(()) } - pub fn deregister_table(&mut self, name: &str) -> PyResult<()> { - self.ctx - .deregister_table(name) - .map_err(DataFusionError::from)?; + pub fn deregister_table(&mut self, name: &str) -> PyDataFusionResult<()> { + self.ctx.deregister_table(name)?; Ok(()) } @@ -578,10 +572,10 @@ impl PySessionContext { &mut self, name: &str, provider: Bound<'_, PyAny>, - ) -> PyResult<()> { + ) -> PyDataFusionResult<()> { if provider.hasattr("__datafusion_table_provider__")? { let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?; - let capsule = capsule.downcast::<PyCapsule>()?; + let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?; validate_pycapsule(capsule, "datafusion_table_provider")?; let provider = unsafe { capsule.reference::<FFI_TableProvider>() }; @@ -591,8 +585,9 @@ impl PySessionContext { Ok(()) } else { - Err(PyNotImplementedError::new_err( - "__datafusion_table_provider__ does not exist on Table Provider object.", + Err(crate::errors::PyDataFusionError::Common( + "__datafusion_table_provider__ does not exist on Table Provider object." + .to_string(), )) } } @@ -601,12 +596,10 @@ impl PySessionContext { &mut self, name: &str, partitions: PyArrowType<Vec<Vec<RecordBatch>>>, - ) -> PyResult<()> { + ) -> PyDataFusionResult<()> { let schema = partitions.0[0][0].schema(); let table = MemTable::try_new(schema, partitions.0)?; - self.ctx - .register_table(name, Arc::new(table)) - .map_err(DataFusionError::from)?; + self.ctx.register_table(name, Arc::new(table))?; Ok(()) } @@ -628,7 +621,7 @@ impl PySessionContext { schema: Option<PyArrowType<Schema>>, file_sort_order: Option<Vec<Vec<PySortExpr>>>, py: Python, - ) -> PyResult<()> { + ) -> PyDataFusionResult<()> { let mut options = ParquetReadOptions::default() .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
.parquet_pruning(parquet_pruning) @@ -642,7 +635,7 @@ impl PySessionContext { .collect(); let result = self.ctx.register_parquet(name, path, options); - wait_for_future(py, result).map_err(DataFusionError::from)?; + wait_for_future(py, result)?; Ok(()) } @@ -666,12 +659,12 @@ impl PySessionContext { file_extension: &str, file_compression_type: Option, py: Python, - ) -> PyResult<()> { + ) -> PyDataFusionResult<()> { let delimiter = delimiter.as_bytes(); if delimiter.len() != 1 { - return Err(PyValueError::new_err( + return Err(crate::errors::PyDataFusionError::PythonError(py_value_err( "Delimiter must be a single character", - )); + ))); } let mut options = CsvReadOptions::new() @@ -685,11 +678,11 @@ impl PySessionContext { if path.is_instance_of::() { let paths = path.extract::>()?; let result = self.register_csv_from_multiple_paths(name, paths, options); - wait_for_future(py, result).map_err(DataFusionError::from)?; + wait_for_future(py, result)?; } else { let path = path.extract::()?; let result = self.ctx.register_csv(name, &path, options); - wait_for_future(py, result).map_err(DataFusionError::from)?; + wait_for_future(py, result)?; } Ok(()) @@ -713,7 +706,7 @@ impl PySessionContext { table_partition_cols: Vec<(String, String)>, file_compression_type: Option, py: Python, - ) -> PyResult<()> { + ) -> PyDataFusionResult<()> { let path = path .to_str() .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?; @@ -726,7 +719,7 @@ impl PySessionContext { options.schema = schema.as_ref().map(|x| &x.0); let result = self.ctx.register_json(name, path, options); - wait_for_future(py, result).map_err(DataFusionError::from)?; + wait_for_future(py, result)?; Ok(()) } @@ -745,7 +738,7 @@ impl PySessionContext { file_extension: &str, table_partition_cols: Vec<(String, String)>, py: Python, - ) -> PyResult<()> { + ) -> PyDataFusionResult<()> { let path = path .to_str() .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?; @@ -756,7 +749,7 @@ impl PySessionContext { options.schema = schema.as_ref().map(|x| &x.0); let result = self.ctx.register_avro(name, path, options); - wait_for_future(py, result).map_err(DataFusionError::from)?; + wait_for_future(py, result)?; Ok(()) } @@ -767,12 +760,10 @@ impl PySessionContext { name: &str, dataset: &Bound<'_, PyAny>, py: Python, - ) -> PyResult<()> { + ) -> PyDataFusionResult<()> { let table: Arc = Arc::new(Dataset::new(dataset, py)?); - self.ctx - .register_table(name, table) - .map_err(DataFusionError::from)?; + self.ctx.register_table(name, table)?; Ok(()) } @@ -824,11 +815,11 @@ impl PySessionContext { Ok(PyDataFrame::new(x)) } - pub fn table_exist(&self, name: &str) -> PyResult { + pub fn table_exist(&self, name: &str) -> PyDataFusionResult { Ok(self.ctx.table_exist(name)?) } - pub fn empty_table(&self) -> PyResult { + pub fn empty_table(&self) -> PyDataFusionResult { Ok(PyDataFrame::new(self.ctx.read_empty()?)) } @@ -847,7 +838,7 @@ impl PySessionContext { table_partition_cols: Vec<(String, String)>, file_compression_type: Option, py: Python, - ) -> PyResult { + ) -> PyDataFusionResult { let path = path .to_str() .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?; @@ -859,10 +850,10 @@ impl PySessionContext { let df = if let Some(schema) = schema { options.schema = Some(&schema.0); let result = self.ctx.read_json(path, options); - wait_for_future(py, result).map_err(DataFusionError::from)? + wait_for_future(py, result)? 
} else { let result = self.ctx.read_json(path, options); - wait_for_future(py, result).map_err(DataFusionError::from)? + wait_for_future(py, result)? }; Ok(PyDataFrame::new(df)) } @@ -888,12 +879,12 @@ impl PySessionContext { table_partition_cols: Vec<(String, String)>, file_compression_type: Option, py: Python, - ) -> PyResult { + ) -> PyDataFusionResult { let delimiter = delimiter.as_bytes(); if delimiter.len() != 1 { - return Err(PyValueError::new_err( + return Err(crate::errors::PyDataFusionError::PythonError(py_value_err( "Delimiter must be a single character", - )); + ))); }; let mut options = CsvReadOptions::new() @@ -909,12 +900,12 @@ impl PySessionContext { let paths = path.extract::>()?; let paths = paths.iter().map(|p| p as &str).collect::>(); let result = self.ctx.read_csv(paths, options); - let df = PyDataFrame::new(wait_for_future(py, result).map_err(DataFusionError::from)?); + let df = PyDataFrame::new(wait_for_future(py, result)?); Ok(df) } else { let path = path.extract::()?; let result = self.ctx.read_csv(path, options); - let df = PyDataFrame::new(wait_for_future(py, result).map_err(DataFusionError::from)?); + let df = PyDataFrame::new(wait_for_future(py, result)?); Ok(df) } } @@ -938,7 +929,7 @@ impl PySessionContext { schema: Option>, file_sort_order: Option>>, py: Python, - ) -> PyResult { + ) -> PyDataFusionResult { let mut options = ParquetReadOptions::default() .table_partition_cols(convert_table_partition_cols(table_partition_cols)?) .parquet_pruning(parquet_pruning) @@ -952,7 +943,7 @@ impl PySessionContext { .collect(); let result = self.ctx.read_parquet(path, options); - let df = PyDataFrame::new(wait_for_future(py, result).map_err(DataFusionError::from)?); + let df = PyDataFrame::new(wait_for_future(py, result)?); Ok(df) } @@ -965,26 +956,23 @@ impl PySessionContext { table_partition_cols: Vec<(String, String)>, file_extension: &str, py: Python, - ) -> PyResult { + ) -> PyDataFusionResult { let mut options = AvroReadOptions::default() .table_partition_cols(convert_table_partition_cols(table_partition_cols)?); options.file_extension = file_extension; let df = if let Some(schema) = schema { options.schema = Some(&schema.0); let read_future = self.ctx.read_avro(path, options); - wait_for_future(py, read_future).map_err(DataFusionError::from)? + wait_for_future(py, read_future)? } else { let read_future = self.ctx.read_avro(path, options); - wait_for_future(py, read_future).map_err(DataFusionError::from)? + wait_for_future(py, read_future)? 
}; Ok(PyDataFrame::new(df)) } - pub fn read_table(&self, table: &PyTable) -> PyResult { - let df = self - .ctx - .read_table(table.table()) - .map_err(DataFusionError::from)?; + pub fn read_table(&self, table: &PyTable) -> PyDataFusionResult { + let df = self.ctx.read_table(table.table())?; Ok(PyDataFrame::new(df)) } @@ -1011,7 +999,7 @@ impl PySessionContext { plan: PyExecutionPlan, part: usize, py: Python, - ) -> PyResult { + ) -> PyDataFusionResult { let ctx: TaskContext = TaskContext::from(&self.ctx.state()); // create a Tokio runtime to run the async code let rt = &get_tokio_runtime().0; @@ -1071,13 +1059,13 @@ impl PySessionContext { pub fn convert_table_partition_cols( table_partition_cols: Vec<(String, String)>, -) -> Result, DataFusionError> { +) -> PyDataFusionResult> { table_partition_cols .into_iter() .map(|(name, ty)| match ty.as_str() { "string" => Ok((name, DataType::Utf8)), "int" => Ok((name, DataType::Int32)), - _ => Err(DataFusionError::Common(format!( + _ => Err(crate::errors::PyDataFusionError::Common(format!( "Unsupported data type '{ty}' for partition column. Supported types are 'string' and 'int'" ))), }) diff --git a/src/dataframe.rs b/src/dataframe.rs index b875480a7..6fb08ba25 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -33,20 +33,20 @@ use datafusion::dataframe::{DataFrame, DataFrameWriteOptions}; use datafusion::execution::SendableRecordBatchStream; use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel}; use datafusion::prelude::*; -use pyo3::exceptions::{PyTypeError, PyValueError}; +use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::pybacked::PyBackedStr; use pyo3::types::{PyCapsule, PyTuple, PyTupleMethods}; use tokio::task::JoinHandle; -use crate::errors::py_datafusion_err; +use crate::errors::{py_datafusion_err, PyDataFusionError}; use crate::expr::sort_expr::to_sort_expressions; use crate::physical_plan::PyExecutionPlan; use crate::record_batch::PyRecordBatchStream; use crate::sql::logical::PyLogicalPlan; use crate::utils::{get_tokio_runtime, validate_pycapsule, wait_for_future}; use crate::{ - errors::DataFusionError, + errors::PyDataFusionResult, expr::{sort_expr::PySortExpr, PyExpr}, }; @@ -69,7 +69,7 @@ impl PyDataFrame { #[pymethods] impl PyDataFrame { /// Enable selection for `df[col]`, `df[col1, col2, col3]`, and `df[[col1, col2, col3]]` - fn __getitem__(&self, key: Bound<'_, PyAny>) -> PyResult { + fn __getitem__(&self, key: Bound<'_, PyAny>) -> PyDataFusionResult { if let Ok(key) = key.extract::() { // df[col] self.select_columns(vec![key]) @@ -84,12 +84,12 @@ impl PyDataFrame { // df[[col1, col2, col3]] self.select_columns(keys) } else { - let message = "DataFrame can only be indexed by string index or indices"; - Err(PyTypeError::new_err(message)) + let message = "DataFrame can only be indexed by string index or indices".to_string(); + Err(PyDataFusionError::Common(message)) } } - fn __repr__(&self, py: Python) -> PyResult { + fn __repr__(&self, py: Python) -> PyDataFusionResult { let df = self.df.as_ref().clone().limit(0, Some(10))?; let batches = wait_for_future(py, df.collect())?; let batches_as_string = pretty::pretty_format_batches(&batches); @@ -99,7 +99,7 @@ impl PyDataFrame { } } - fn _repr_html_(&self, py: Python) -> PyResult { + fn _repr_html_(&self, py: Python) -> PyDataFusionResult { let mut html_str = "\n".to_string(); let df = self.df.as_ref().clone().limit(0, Some(10))?; @@ -145,7 +145,7 @@ impl PyDataFrame { } /// Calculate summary statistics for a DataFrame - fn 
describe(&self, py: Python) -> PyResult { + fn describe(&self, py: Python) -> PyDataFusionResult { let df = self.df.as_ref().clone(); let stat_df = wait_for_future(py, df.describe())?; Ok(Self::new(stat_df)) @@ -157,37 +157,37 @@ impl PyDataFrame { } #[pyo3(signature = (*args))] - fn select_columns(&self, args: Vec) -> PyResult { + fn select_columns(&self, args: Vec) -> PyDataFusionResult { let args = args.iter().map(|s| s.as_ref()).collect::>(); let df = self.df.as_ref().clone().select_columns(&args)?; Ok(Self::new(df)) } #[pyo3(signature = (*args))] - fn select(&self, args: Vec) -> PyResult { + fn select(&self, args: Vec) -> PyDataFusionResult { let expr = args.into_iter().map(|e| e.into()).collect(); let df = self.df.as_ref().clone().select(expr)?; Ok(Self::new(df)) } #[pyo3(signature = (*args))] - fn drop(&self, args: Vec) -> PyResult { + fn drop(&self, args: Vec) -> PyDataFusionResult { let cols = args.iter().map(|s| s.as_ref()).collect::>(); let df = self.df.as_ref().clone().drop_columns(&cols)?; Ok(Self::new(df)) } - fn filter(&self, predicate: PyExpr) -> PyResult { + fn filter(&self, predicate: PyExpr) -> PyDataFusionResult { let df = self.df.as_ref().clone().filter(predicate.into())?; Ok(Self::new(df)) } - fn with_column(&self, name: &str, expr: PyExpr) -> PyResult { + fn with_column(&self, name: &str, expr: PyExpr) -> PyDataFusionResult { let df = self.df.as_ref().clone().with_column(name, expr.into())?; Ok(Self::new(df)) } - fn with_columns(&self, exprs: Vec) -> PyResult { + fn with_columns(&self, exprs: Vec) -> PyDataFusionResult { let mut df = self.df.as_ref().clone(); for expr in exprs { let expr: Expr = expr.into(); @@ -199,7 +199,7 @@ impl PyDataFrame { /// Rename one column by applying a new projection. This is a no-op if the column to be /// renamed does not exist. - fn with_column_renamed(&self, old_name: &str, new_name: &str) -> PyResult { + fn with_column_renamed(&self, old_name: &str, new_name: &str) -> PyDataFusionResult { let df = self .df .as_ref() @@ -208,7 +208,7 @@ impl PyDataFrame { Ok(Self::new(df)) } - fn aggregate(&self, group_by: Vec, aggs: Vec) -> PyResult { + fn aggregate(&self, group_by: Vec, aggs: Vec) -> PyDataFusionResult { let group_by = group_by.into_iter().map(|e| e.into()).collect(); let aggs = aggs.into_iter().map(|e| e.into()).collect(); let df = self.df.as_ref().clone().aggregate(group_by, aggs)?; @@ -216,14 +216,14 @@ impl PyDataFrame { } #[pyo3(signature = (*exprs))] - fn sort(&self, exprs: Vec) -> PyResult { + fn sort(&self, exprs: Vec) -> PyDataFusionResult { let exprs = to_sort_expressions(exprs); let df = self.df.as_ref().clone().sort(exprs)?; Ok(Self::new(df)) } #[pyo3(signature = (count, offset=0))] - fn limit(&self, count: usize, offset: usize) -> PyResult { + fn limit(&self, count: usize, offset: usize) -> PyDataFusionResult { let df = self.df.as_ref().clone().limit(offset, Some(count))?; Ok(Self::new(df)) } @@ -232,14 +232,15 @@ impl PyDataFrame { /// Unless some order is specified in the plan, there is no /// guarantee of the order of the result. fn collect(&self, py: Python) -> PyResult> { - let batches = wait_for_future(py, self.df.as_ref().clone().collect())?; + let batches = wait_for_future(py, self.df.as_ref().clone().collect()) + .map_err(PyDataFusionError::from)?; // cannot use PyResult> return type due to // https://github.com/PyO3/pyo3/issues/1813 batches.into_iter().map(|rb| rb.to_pyarrow(py)).collect() } /// Cache DataFrame. 
- fn cache(&self, py: Python) -> PyResult { + fn cache(&self, py: Python) -> PyDataFusionResult { let df = wait_for_future(py, self.df.as_ref().clone().cache())?; Ok(Self::new(df)) } @@ -247,7 +248,8 @@ impl PyDataFrame { /// Executes this DataFrame and collects all results into a vector of vector of RecordBatch /// maintaining the input partitioning. fn collect_partitioned(&self, py: Python) -> PyResult>> { - let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned())?; + let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned()) + .map_err(PyDataFusionError::from)?; batches .into_iter() @@ -257,13 +259,13 @@ impl PyDataFrame { /// Print the result, 20 lines by default #[pyo3(signature = (num=20))] - fn show(&self, py: Python, num: usize) -> PyResult<()> { + fn show(&self, py: Python, num: usize) -> PyDataFusionResult<()> { let df = self.df.as_ref().clone().limit(0, Some(num))?; print_dataframe(py, df) } /// Filter out duplicate rows - fn distinct(&self) -> PyResult { + fn distinct(&self) -> PyDataFusionResult { let df = self.df.as_ref().clone().distinct()?; Ok(Self::new(df)) } @@ -274,7 +276,7 @@ impl PyDataFrame { how: &str, left_on: Vec, right_on: Vec, - ) -> PyResult { + ) -> PyDataFusionResult { let join_type = match how { "inner" => JoinType::Inner, "left" => JoinType::Left, @@ -283,10 +285,9 @@ impl PyDataFrame { "semi" => JoinType::LeftSemi, "anti" => JoinType::LeftAnti, how => { - return Err(DataFusionError::Common(format!( + return Err(PyDataFusionError::Common(format!( "The join type {how} does not exist or is not implemented" - )) - .into()); + ))); } }; @@ -303,7 +304,12 @@ impl PyDataFrame { Ok(Self::new(df)) } - fn join_on(&self, right: PyDataFrame, on_exprs: Vec, how: &str) -> PyResult { + fn join_on( + &self, + right: PyDataFrame, + on_exprs: Vec, + how: &str, + ) -> PyDataFusionResult { let join_type = match how { "inner" => JoinType::Inner, "left" => JoinType::Left, @@ -312,10 +318,9 @@ impl PyDataFrame { "semi" => JoinType::LeftSemi, "anti" => JoinType::LeftAnti, how => { - return Err(DataFusionError::Common(format!( + return Err(PyDataFusionError::Common(format!( "The join type {how} does not exist or is not implemented" - )) - .into()); + ))); } }; let exprs: Vec = on_exprs.into_iter().map(|e| e.into()).collect(); @@ -330,7 +335,7 @@ impl PyDataFrame { /// Print the query plan #[pyo3(signature = (verbose=false, analyze=false))] - fn explain(&self, py: Python, verbose: bool, analyze: bool) -> PyResult<()> { + fn explain(&self, py: Python, verbose: bool, analyze: bool) -> PyDataFusionResult<()> { let df = self.df.as_ref().clone().explain(verbose, analyze)?; print_dataframe(py, df) } @@ -341,18 +346,18 @@ impl PyDataFrame { } /// Get the optimized logical plan for this `DataFrame` - fn optimized_logical_plan(&self) -> PyResult { + fn optimized_logical_plan(&self) -> PyDataFusionResult { Ok(self.df.as_ref().clone().into_optimized_plan()?.into()) } /// Get the execution plan for this `DataFrame` - fn execution_plan(&self, py: Python) -> PyResult { + fn execution_plan(&self, py: Python) -> PyDataFusionResult { let plan = wait_for_future(py, self.df.as_ref().clone().create_physical_plan())?; Ok(plan.into()) } /// Repartition a `DataFrame` based on a logical partitioning scheme. - fn repartition(&self, num: usize) -> PyResult { + fn repartition(&self, num: usize) -> PyDataFusionResult { let new_df = self .df .as_ref() @@ -363,7 +368,7 @@ impl PyDataFrame { /// Repartition a `DataFrame` based on a logical partitioning scheme. 
#[pyo3(signature = (*args, num))] - fn repartition_by_hash(&self, args: Vec, num: usize) -> PyResult { + fn repartition_by_hash(&self, args: Vec, num: usize) -> PyDataFusionResult { let expr = args.into_iter().map(|py_expr| py_expr.into()).collect(); let new_df = self .df @@ -376,7 +381,7 @@ impl PyDataFrame { /// Calculate the union of two `DataFrame`s, preserving duplicate rows.The /// two `DataFrame`s must have exactly the same schema #[pyo3(signature = (py_df, distinct=false))] - fn union(&self, py_df: PyDataFrame, distinct: bool) -> PyResult { + fn union(&self, py_df: PyDataFrame, distinct: bool) -> PyDataFusionResult { let new_df = if distinct { self.df .as_ref() @@ -391,7 +396,7 @@ impl PyDataFrame { /// Calculate the distinct union of two `DataFrame`s. The /// two `DataFrame`s must have exactly the same schema - fn union_distinct(&self, py_df: PyDataFrame) -> PyResult { + fn union_distinct(&self, py_df: PyDataFrame) -> PyDataFusionResult { let new_df = self .df .as_ref() @@ -401,7 +406,7 @@ impl PyDataFrame { } #[pyo3(signature = (column, preserve_nulls=true))] - fn unnest_column(&self, column: &str, preserve_nulls: bool) -> PyResult { + fn unnest_column(&self, column: &str, preserve_nulls: bool) -> PyDataFusionResult { // TODO: expose RecursionUnnestOptions // REF: https://github.com/apache/datafusion/pull/11577 let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls); @@ -414,7 +419,11 @@ impl PyDataFrame { } #[pyo3(signature = (columns, preserve_nulls=true))] - fn unnest_columns(&self, columns: Vec, preserve_nulls: bool) -> PyResult { + fn unnest_columns( + &self, + columns: Vec, + preserve_nulls: bool, + ) -> PyDataFusionResult { // TODO: expose RecursionUnnestOptions // REF: https://github.com/apache/datafusion/pull/11577 let unnest_options = UnnestOptions::default().with_preserve_nulls(preserve_nulls); @@ -428,7 +437,7 @@ impl PyDataFrame { } /// Calculate the intersection of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema - fn intersect(&self, py_df: PyDataFrame) -> PyResult { + fn intersect(&self, py_df: PyDataFrame) -> PyDataFusionResult { let new_df = self .df .as_ref() @@ -438,13 +447,13 @@ impl PyDataFrame { } /// Calculate the exception of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema - fn except_all(&self, py_df: PyDataFrame) -> PyResult { + fn except_all(&self, py_df: PyDataFrame) -> PyDataFusionResult { let new_df = self.df.as_ref().clone().except(py_df.df.as_ref().clone())?; Ok(Self::new(new_df)) } /// Write a `DataFrame` to a CSV file. - fn write_csv(&self, path: &str, with_header: bool, py: Python) -> PyResult<()> { + fn write_csv(&self, path: &str, with_header: bool, py: Python) -> PyDataFusionResult<()> { let csv_options = CsvOptions { has_header: Some(with_header), ..Default::default() @@ -472,7 +481,7 @@ impl PyDataFrame { compression: &str, compression_level: Option, py: Python, - ) -> PyResult<()> { + ) -> PyDataFusionResult<()> { fn verify_compression_level(cl: Option) -> Result { cl.ok_or(PyValueError::new_err("compression_level is not defined")) } @@ -496,7 +505,7 @@ impl PyDataFrame { "lz4_raw" => Compression::LZ4_RAW, "uncompressed" => Compression::UNCOMPRESSED, _ => { - return Err(PyValueError::new_err(format!( + return Err(PyDataFusionError::Common(format!( "Unrecognized compression type {compression}" ))); } @@ -522,7 +531,7 @@ impl PyDataFrame { } /// Executes a query and writes the results to a partitioned JSON file. 
- fn write_json(&self, path: &str, py: Python) -> PyResult<()> { + fn write_json(&self, path: &str, py: Python) -> PyDataFusionResult<()> { wait_for_future( py, self.df @@ -551,7 +560,7 @@ impl PyDataFrame { &'py mut self, py: Python<'py>, requested_schema: Option>, - ) -> PyResult> { + ) -> PyDataFusionResult> { let mut batches = wait_for_future(py, self.df.as_ref().clone().collect())?; let mut schema: Schema = self.df.schema().to_owned().into(); @@ -559,15 +568,14 @@ impl PyDataFrame { validate_pycapsule(&schema_capsule, "arrow_schema")?; let schema_ptr = unsafe { schema_capsule.reference::() }; - let desired_schema = Schema::try_from(schema_ptr).map_err(DataFusionError::from)?; + let desired_schema = Schema::try_from(schema_ptr)?; - schema = project_schema(schema, desired_schema).map_err(DataFusionError::ArrowError)?; + schema = project_schema(schema, desired_schema)?; batches = batches .into_iter() .map(|record_batch| record_batch_into_schema(record_batch, &schema)) - .collect::, ArrowError>>() - .map_err(DataFusionError::ArrowError)?; + .collect::, ArrowError>>()?; } let batches_wrapped = batches.into_iter().map(Ok); @@ -578,9 +586,10 @@ impl PyDataFrame { let ffi_stream = FFI_ArrowArrayStream::new(reader); let stream_capsule_name = CString::new("arrow_array_stream").unwrap(); PyCapsule::new_bound(py, ffi_stream, Some(stream_capsule_name)) + .map_err(PyDataFusionError::from) } - fn execute_stream(&self, py: Python) -> PyResult { + fn execute_stream(&self, py: Python) -> PyDataFusionResult { // create a Tokio runtime to run the async code let rt = &get_tokio_runtime().0; let df = self.df.as_ref().clone(); @@ -647,13 +656,13 @@ impl PyDataFrame { } // Executes this DataFrame to get the total number of rows. - fn count(&self, py: Python) -> PyResult { + fn count(&self, py: Python) -> PyDataFusionResult { Ok(wait_for_future(py, self.df.as_ref().clone().count())?) 
} } /// Print DataFrame -fn print_dataframe(py: Python, df: DataFrame) -> PyResult<()> { +fn print_dataframe(py: Python, df: DataFrame) -> PyDataFusionResult<()> { // Get string representation of record batches let batches = wait_for_future(py, df.collect())?; let batches_as_string = pretty::pretty_format_batches(&batches); diff --git a/src/dataset_exec.rs b/src/dataset_exec.rs index 9d2559429..ace42115b 100644 --- a/src/dataset_exec.rs +++ b/src/dataset_exec.rs @@ -42,7 +42,7 @@ use datafusion::physical_plan::{ SendableRecordBatchStream, Statistics, }; -use crate::errors::DataFusionError; +use crate::errors::PyDataFusionResult; use crate::pyarrow_filter_expression::PyArrowFilterExpression; struct PyArrowBatchesAdapter { @@ -83,8 +83,8 @@ impl DatasetExec { dataset: &Bound<'_, PyAny>, projection: Option>, filters: &[Expr], - ) -> Result { - let columns: Option, DataFusionError>> = projection.map(|p| { + ) -> PyDataFusionResult { + let columns: Option>> = projection.map(|p| { p.iter() .map(|index| { let name: String = dataset diff --git a/src/errors.rs b/src/errors.rs index d12b6ade1..b02b754a2 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -24,10 +24,10 @@ use datafusion::error::DataFusionError as InnerDataFusionError; use prost::EncodeError; use pyo3::{exceptions::PyException, PyErr}; -pub type Result = std::result::Result; +pub type PyDataFusionResult = std::result::Result; #[derive(Debug)] -pub enum DataFusionError { +pub enum PyDataFusionError { ExecutionError(InnerDataFusionError), ArrowError(ArrowError), Common(String), @@ -35,46 +35,46 @@ pub enum DataFusionError { EncodeError(EncodeError), } -impl fmt::Display for DataFusionError { +impl fmt::Display for PyDataFusionError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - DataFusionError::ExecutionError(e) => write!(f, "DataFusion error: {e:?}"), - DataFusionError::ArrowError(e) => write!(f, "Arrow error: {e:?}"), - DataFusionError::PythonError(e) => write!(f, "Python error {e:?}"), - DataFusionError::Common(e) => write!(f, "{e}"), - DataFusionError::EncodeError(e) => write!(f, "Failed to encode substrait plan: {e}"), + PyDataFusionError::ExecutionError(e) => write!(f, "DataFusion error: {e:?}"), + PyDataFusionError::ArrowError(e) => write!(f, "Arrow error: {e:?}"), + PyDataFusionError::PythonError(e) => write!(f, "Python error {e:?}"), + PyDataFusionError::Common(e) => write!(f, "{e}"), + PyDataFusionError::EncodeError(e) => write!(f, "Failed to encode substrait plan: {e}"), } } } -impl From for DataFusionError { - fn from(err: ArrowError) -> DataFusionError { - DataFusionError::ArrowError(err) +impl From for PyDataFusionError { + fn from(err: ArrowError) -> PyDataFusionError { + PyDataFusionError::ArrowError(err) } } -impl From for DataFusionError { - fn from(err: InnerDataFusionError) -> DataFusionError { - DataFusionError::ExecutionError(err) +impl From for PyDataFusionError { + fn from(err: InnerDataFusionError) -> PyDataFusionError { + PyDataFusionError::ExecutionError(err) } } -impl From for DataFusionError { - fn from(err: PyErr) -> DataFusionError { - DataFusionError::PythonError(err) +impl From for PyDataFusionError { + fn from(err: PyErr) -> PyDataFusionError { + PyDataFusionError::PythonError(err) } } -impl From for PyErr { - fn from(err: DataFusionError) -> PyErr { +impl From for PyErr { + fn from(err: PyDataFusionError) -> PyErr { match err { - DataFusionError::PythonError(py_err) => py_err, + PyDataFusionError::PythonError(py_err) => py_err, _ => PyException::new_err(err.to_string()), } } } 
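As a sketch of the calling pattern this errors.rs change enables (the helper below is hypothetical, not part of this diff): a binding returns PyDataFusionResult<T> and propagates failures with a plain `?`; DataFusionError and ArrowError convert into PyDataFusionError through the From impls above, and the error then surfaces in Python as an exception via the From<PyDataFusionError> for PyErr impl.

// Hypothetical illustration only; assumes the crate's `wait_for_future`
// utility and datafusion's `DataFrame`, both visible elsewhere in this diff.
fn collect_row_count(py: Python, df: datafusion::dataframe::DataFrame) -> PyDataFusionResult<usize> {
    // `?` maps DataFusionError into PyDataFusionError via `From`.
    let batches = wait_for_future(py, df.collect())?;
    // When a #[pymethods] caller returns this Err, pyo3 raises it as a
    // Python Exception through `From<PyDataFusionError> for PyErr`.
    Ok(batches.iter().map(|b| b.num_rows()).sum())
}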
-impl Error for DataFusionError {} +impl Error for PyDataFusionError {} pub fn py_type_err(e: impl Debug) -> PyErr { PyErr::new::(format!("{e:?}")) diff --git a/src/expr.rs b/src/expr.rs index bca0cd3fa..1e9983d42 100644 --- a/src/expr.rs +++ b/src/expr.rs @@ -24,7 +24,6 @@ use std::convert::{From, Into}; use std::sync::Arc; use window::PyWindowFrame; -use arrow::pyarrow::ToPyArrow; use datafusion::arrow::datatypes::{DataType, Field}; use datafusion::arrow::pyarrow::PyArrowType; use datafusion::functions::core::expr_ext::FieldAccessor; @@ -33,15 +32,17 @@ use datafusion::logical_expr::{ expr::{AggregateFunction, InList, InSubquery, ScalarFunction, WindowFunction}, lit, Between, BinaryExpr, Case, Cast, Expr, Like, Operator, TryCast, }; -use datafusion::scalar::ScalarValue; -use crate::common::data_type::{DataTypeMap, NullTreatment, RexType}; -use crate::errors::{py_runtime_err, py_type_err, py_unsupported_variant_err, DataFusionError}; +use crate::common::data_type::{DataTypeMap, NullTreatment, PyScalarValue, RexType}; +use crate::errors::{ + py_runtime_err, py_type_err, py_unsupported_variant_err, PyDataFusionError, PyDataFusionResult, +}; use crate::expr::aggregate_expr::PyAggregateFunction; use crate::expr::binary_expr::PyBinaryExpr; use crate::expr::column::PyColumn; use crate::expr::literal::PyLiteral; use crate::functions::add_builder_fns_to_window; +use crate::pyarrow_util::scalar_to_pyarrow; use crate::sql::logical::PyLogicalPlan; use self::alias::PyAlias; @@ -261,8 +262,8 @@ impl PyExpr { } #[staticmethod] - pub fn literal(value: ScalarValue) -> PyExpr { - lit(value).into() + pub fn literal(value: PyScalarValue) -> PyExpr { + lit(value.0).into() } #[staticmethod] @@ -356,7 +357,7 @@ impl PyExpr { /// Extracts the Expr value into a PyObject that can be shared with Python pub fn python_value(&self, py: Python) -> PyResult { match &self.expr { - Expr::Literal(scalar_value) => Ok(scalar_value.to_pyarrow(py)?), + Expr::Literal(scalar_value) => scalar_to_pyarrow(scalar_value, py), _ => Err(py_type_err(format!( "Non Expr::Literal encountered in types: {:?}", &self.expr @@ -568,7 +569,7 @@ impl PyExpr { window_frame: Option, order_by: Option>, null_treatment: Option, - ) -> PyResult { + ) -> PyDataFusionResult { match &self.expr { Expr::AggregateFunction(agg_fn) => { let window_fn = Expr::WindowFunction(WindowFunction::new( @@ -592,10 +593,9 @@ impl PyExpr { null_treatment, ), _ => Err( - DataFusionError::ExecutionError(datafusion::error::DataFusionError::Plan( + PyDataFusionError::ExecutionError(datafusion::error::DataFusionError::Plan( format!("Using {} with `over` is not allowed. Must use an aggregate or window function.", self.expr.variant_name()), )) - .into(), ), } } @@ -649,34 +649,26 @@ impl PyExprFuncBuilder { .into() } - pub fn build(&self) -> PyResult { - self.builder - .clone() - .build() - .map(|expr| expr.into()) - .map_err(|err| err.into()) + pub fn build(&self) -> PyDataFusionResult { + Ok(self.builder.clone().build().map(|expr| expr.into())?) 
} } impl PyExpr { - pub fn _column_name(&self, plan: &LogicalPlan) -> Result { + pub fn _column_name(&self, plan: &LogicalPlan) -> PyDataFusionResult { let field = Self::expr_to_field(&self.expr, plan)?; Ok(field.name().to_owned()) } /// Create a [Field] representing an [Expr], given an input [LogicalPlan] to resolve against - pub fn expr_to_field( - expr: &Expr, - input_plan: &LogicalPlan, - ) -> Result, DataFusionError> { + pub fn expr_to_field(expr: &Expr, input_plan: &LogicalPlan) -> PyDataFusionResult> { match expr { Expr::Wildcard { .. } => { // Since * could be any of the valid column names just return the first one Ok(Arc::new(input_plan.schema().field(0).clone())) } _ => { - let fields = - exprlist_to_fields(&[expr.clone()], input_plan).map_err(PyErr::from)?; + let fields = exprlist_to_fields(&[expr.clone()], input_plan)?; Ok(fields[0].1.clone()) } } diff --git a/src/expr/conditional_expr.rs b/src/expr/conditional_expr.rs index a8a885c54..fe3af2e25 100644 --- a/src/expr/conditional_expr.rs +++ b/src/expr/conditional_expr.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::expr::PyExpr; +use crate::{errors::PyDataFusionResult, expr::PyExpr}; use datafusion::logical_expr::conditional_expressions::CaseBuilder; use pyo3::prelude::*; @@ -44,11 +44,11 @@ impl PyCaseBuilder { } } - fn otherwise(&mut self, else_expr: PyExpr) -> PyResult { + fn otherwise(&mut self, else_expr: PyExpr) -> PyDataFusionResult { Ok(self.case_builder.otherwise(else_expr.expr)?.clone().into()) } - fn end(&mut self) -> PyResult { + fn end(&mut self) -> PyDataFusionResult { Ok(self.case_builder.end()?.clone().into()) } } diff --git a/src/expr/literal.rs b/src/expr/literal.rs index 43084ba4b..2cb2079f1 100644 --- a/src/expr/literal.rs +++ b/src/expr/literal.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use crate::errors::DataFusionError; +use crate::errors::PyDataFusionError; use datafusion::common::ScalarValue; use pyo3::prelude::*; @@ -154,5 +154,5 @@ impl PyLiteral { } fn unexpected_literal_value(value: &ScalarValue) -> PyErr { - DataFusionError::Common(format!("getValue() - Unexpected value: {value}")).into() + PyDataFusionError::Common(format!("getValue() - Unexpected value: {value}")).into() } diff --git a/src/expr/window.rs b/src/expr/window.rs index 6486dbb32..4dc6cb9c9 100644 --- a/src/expr/window.rs +++ b/src/expr/window.rs @@ -21,8 +21,9 @@ use datafusion::logical_expr::{Expr, Window, WindowFrame, WindowFrameBound, Wind use pyo3::prelude::*; use std::fmt::{self, Display, Formatter}; +use crate::common::data_type::PyScalarValue; use crate::common::df_schema::PyDFSchema; -use crate::errors::py_type_err; +use crate::errors::{py_type_err, PyDataFusionResult}; use crate::expr::logical_node::LogicalNode; use crate::expr::sort_expr::{py_sort_expr_list, PySortExpr}; use crate::expr::PyExpr; @@ -171,8 +172,8 @@ impl PyWindowFrame { #[pyo3(signature=(unit, start_bound, end_bound))] pub fn new( unit: &str, - start_bound: Option, - end_bound: Option, + start_bound: Option, + end_bound: Option, ) -> PyResult { let units = unit.to_ascii_lowercase(); let units = match units.as_str() { @@ -187,7 +188,7 @@ impl PyWindowFrame { } }; let start_bound = match start_bound { - Some(start_bound) => WindowFrameBound::Preceding(start_bound), + Some(start_bound) => WindowFrameBound::Preceding(start_bound.0), None => match units { WindowFrameUnits::Range => WindowFrameBound::Preceding(ScalarValue::UInt64(None)), WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)), @@ -200,7 +201,7 @@ impl PyWindowFrame { }, }; let end_bound = match end_bound { - Some(end_bound) => WindowFrameBound::Following(end_bound), + Some(end_bound) => WindowFrameBound::Following(end_bound.0), None => match units { WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)), WindowFrameUnits::Range => WindowFrameBound::Following(ScalarValue::UInt64(None)), @@ -253,7 +254,7 @@ impl PyWindowFrameBound { matches!(self.frame_bound, WindowFrameBound::Following(_)) } /// Returns the offset of the window frame - pub fn get_offset(&self) -> PyResult> { + pub fn get_offset(&self) -> PyDataFusionResult> { match &self.frame_bound { WindowFrameBound::Preceding(val) | WindowFrameBound::Following(val) => match val { x if x.is_null() => Ok(None), diff --git a/src/functions.rs b/src/functions.rs index ae032d702..46c748cf8 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -22,8 +22,10 @@ use datafusion::logical_expr::WindowFrame; use pyo3::{prelude::*, wrap_pyfunction}; use crate::common::data_type::NullTreatment; +use crate::common::data_type::PyScalarValue; use crate::context::PySessionContext; -use crate::errors::DataFusionError; +use crate::errors::PyDataFusionError; +use crate::errors::PyDataFusionResult; use crate::expr::conditional_expr::PyCaseBuilder; use crate::expr::sort_expr::to_sort_expressions; use crate::expr::sort_expr::PySortExpr; @@ -44,7 +46,7 @@ fn add_builder_fns_to_aggregate( filter: Option, order_by: Option>, null_treatment: Option, -) -> PyResult { +) -> PyDataFusionResult { // Since ExprFuncBuilder::new() is private, we can guarantee initializing // a builder with an `null_treatment` with option None let mut builder = agg_fn.null_treatment(None); @@ -228,7 +230,10 @@ fn when(when: PyExpr, then: PyExpr) -> PyResult { /// 1) If no function has been found, search default 
diff --git a/src/functions.rs b/src/functions.rs
index ae032d702..46c748cf8 100644
--- a/src/functions.rs
+++ b/src/functions.rs
@@ -22,8 +22,10 @@ use datafusion::logical_expr::WindowFrame;
 use pyo3::{prelude::*, wrap_pyfunction};
 
 use crate::common::data_type::NullTreatment;
+use crate::common::data_type::PyScalarValue;
 use crate::context::PySessionContext;
-use crate::errors::DataFusionError;
+use crate::errors::PyDataFusionError;
+use crate::errors::PyDataFusionResult;
 use crate::expr::conditional_expr::PyCaseBuilder;
 use crate::expr::sort_expr::to_sort_expressions;
 use crate::expr::sort_expr::PySortExpr;
@@ -44,7 +46,7 @@ fn add_builder_fns_to_aggregate(
     filter: Option<PyExpr>,
     order_by: Option<Vec<PySortExpr>>,
     null_treatment: Option<NullTreatment>,
-) -> PyResult<PyExpr> {
+) -> PyDataFusionResult<PyExpr> {
     // Since ExprFuncBuilder::new() is private, we can guarantee initializing
     // a builder with an `null_treatment` with option None
     let mut builder = agg_fn.null_treatment(None);
@@ -228,7 +230,10 @@ fn when(when: PyExpr, then: PyExpr) -> PyResult<PyCaseBuilder> {
 /// 1) If no function has been found, search default aggregate functions.
 ///
 /// NOTE: we search the built-ins first because the `UDAF` versions currently do not have the same behavior.
-fn find_window_fn(name: &str, ctx: Option<PySessionContext>) -> PyResult<WindowFunctionDefinition> {
+fn find_window_fn(
+    name: &str,
+    ctx: Option<PySessionContext>,
+) -> PyDataFusionResult<WindowFunctionDefinition> {
     if let Some(ctx) = ctx {
         // search UDAFs
         let udaf = ctx
@@ -284,7 +289,9 @@ fn find_window_fn(name: &str, ctx: Option<PySessionContext>) -> PyResult
             order_by: Option<Vec<PySortExpr>>,
             null_treatment: Option<NullTreatment>
-        ) -> PyResult<PyExpr> {
+        ) -> PyDataFusionResult<PyExpr> {
             let agg_fn = functions_aggregate::expr_fn::$NAME($($arg.into()),*);
 
             add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment)
@@ -362,7 +369,7 @@ macro_rules! aggregate_function_vec_args {
             order_by: Option<Vec<PySortExpr>>,
             null_treatment: Option<NullTreatment>
-        ) -> PyResult<PyExpr> {
+        ) -> PyDataFusionResult<PyExpr> {
             let agg_fn = functions_aggregate::expr_fn::$NAME(vec![$($arg.into()),*]);
 
             add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment)
@@ -642,7 +649,7 @@ pub fn approx_percentile_cont(
     percentile: f64,
     num_centroids: Option<i64>, // enforces optional arguments at the end, currently
     filter: Option<PyExpr>,
-) -> PyResult<PyExpr> {
+) -> PyDataFusionResult<PyExpr> {
     let args = if let Some(num_centroids) = num_centroids {
         vec![expression.expr, lit(percentile), lit(num_centroids)]
     } else {
@@ -661,7 +668,7 @@ pub fn approx_percentile_cont_with_weight(
     weight: PyExpr,
     percentile: f64,
     filter: Option<PyExpr>,
-) -> PyResult<PyExpr> {
+) -> PyDataFusionResult<PyExpr> {
     let agg_fn = functions_aggregate::expr_fn::approx_percentile_cont_with_weight(
         expression.expr,
         weight.expr,
@@ -683,7 +690,7 @@ pub fn first_value(
     filter: Option<PyExpr>,
     order_by: Option<Vec<PySortExpr>>,
     null_treatment: Option<NullTreatment>,
-) -> PyResult<PyExpr> {
+) -> PyDataFusionResult<PyExpr> {
     // If we initialize the UDAF with order_by directly, then it gets over-written by the builder
     let agg_fn = functions_aggregate::expr_fn::first_value(expr.expr, None);
 
@@ -700,7 +707,7 @@ pub fn nth_value(
     filter: Option<PyExpr>,
     order_by: Option<Vec<PySortExpr>>,
     null_treatment: Option<NullTreatment>,
-) -> PyResult<PyExpr> {
+) -> PyDataFusionResult<PyExpr> {
     let agg_fn = datafusion::functions_aggregate::nth_value::nth_value(expr.expr, n, vec![]);
     add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment)
 }
@@ -715,7 +722,7 @@ pub fn string_agg(
     filter: Option<PyExpr>,
     order_by: Option<Vec<PySortExpr>>,
     null_treatment: Option<NullTreatment>,
-) -> PyResult<PyExpr> {
+) -> PyDataFusionResult<PyExpr> {
     let agg_fn = datafusion::functions_aggregate::string_agg::string_agg(expr.expr, lit(delimiter));
     add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment)
 }
@@ -726,7 +733,7 @@ pub(crate) fn add_builder_fns_to_window(
     window_frame: Option<PyWindowFrame>,
     order_by: Option<Vec<PySortExpr>>,
     null_treatment: Option<NullTreatment>,
-) -> PyResult<PyExpr> {
+) -> PyDataFusionResult<PyExpr> {
     let null_treatment = null_treatment.map(|n| n.into());
     let mut builder = window_fn.null_treatment(null_treatment);
 
@@ -748,7 +755,7 @@ pub(crate) fn add_builder_fns_to_window(
         builder = builder.window_frame(window_frame.into());
     }
 
-    builder.build().map(|e| e.into()).map_err(|err| err.into())
+    Ok(builder.build().map(|e| e.into())?)
 }
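[Editor's note — hedged sketch, not part of the patch. `add_builder_fns_to_aggregate` above drives DataFusion's expression-function builder; this is the equivalent chain written against plain DataFusion expressions, with no Py wrappers. The `ExprFunctionExt` trait and `expr_fn::sum` are from the datafusion crate; the column names are invented for illustration.]

// Roughly what add_builder_fns_to_aggregate produces for
// SUM(DISTINCT x) FILTER (WHERE x > 0):
use datafusion::functions_aggregate::expr_fn::sum;
use datafusion::logical_expr::{Expr, ExprFunctionExt};
use datafusion::prelude::{col, lit};

fn sum_distinct_positive() -> datafusion::common::Result<Expr> {
    sum(col("x"))
        .distinct()                       // corresponds to the `distinct` flag
        .filter(col("x").gt(lit(0)))      // corresponds to the `filter` argument
        .build()                          // finalize the AggregateFunction expr
}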
 
 #[pyfunction]
@@ -756,10 +763,11 @@ pub(crate) fn add_builder_fns_to_window(
 pub fn lead(
     arg: PyExpr,
     shift_offset: i64,
-    default_value: Option<ScalarValue>,
+    default_value: Option<PyScalarValue>,
     partition_by: Option<Vec<PyExpr>>,
     order_by: Option<Vec<PySortExpr>>,
-) -> PyResult<PyExpr> {
+) -> PyDataFusionResult<PyExpr> {
+    let default_value = default_value.map(|v| v.into());
     let window_fn = functions_window::expr_fn::lead(arg.expr, Some(shift_offset), default_value);
 
     add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
@@ -770,10 +778,11 @@ pub fn lead(
 pub fn lag(
     arg: PyExpr,
     shift_offset: i64,
-    default_value: Option<ScalarValue>,
+    default_value: Option<PyScalarValue>,
     partition_by: Option<Vec<PyExpr>>,
     order_by: Option<Vec<PySortExpr>>,
-) -> PyResult<PyExpr> {
+) -> PyDataFusionResult<PyExpr> {
+    let default_value = default_value.map(|v| v.into());
     let window_fn = functions_window::expr_fn::lag(arg.expr, Some(shift_offset), default_value);
 
     add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
@@ -784,7 +793,7 @@ pub fn lag(
 pub fn row_number(
     partition_by: Option<Vec<PyExpr>>,
     order_by: Option<Vec<PySortExpr>>,
-) -> PyResult<PyExpr> {
+) -> PyDataFusionResult<PyExpr> {
     let window_fn = functions_window::expr_fn::row_number();
 
     add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
@@ -795,7 +804,7 @@ pub fn row_number(
 pub fn rank(
     partition_by: Option<Vec<PyExpr>>,
     order_by: Option<Vec<PySortExpr>>,
-) -> PyResult<PyExpr> {
+) -> PyDataFusionResult<PyExpr> {
     let window_fn = functions_window::expr_fn::rank();
 
     add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
@@ -806,7 +815,7 @@ pub fn rank(
 pub fn dense_rank(
     partition_by: Option<Vec<PyExpr>>,
     order_by: Option<Vec<PySortExpr>>,
-) -> PyResult<PyExpr> {
+) -> PyDataFusionResult<PyExpr> {
     let window_fn = functions_window::expr_fn::dense_rank();
 
     add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
@@ -817,7 +826,7 @@ pub fn dense_rank(
 pub fn percent_rank(
     partition_by: Option<Vec<PyExpr>>,
     order_by: Option<Vec<PySortExpr>>,
-) -> PyResult<PyExpr> {
+) -> PyDataFusionResult<PyExpr> {
     let window_fn = functions_window::expr_fn::percent_rank();
 
     add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
@@ -828,7 +837,7 @@ pub fn percent_rank(
 pub fn cume_dist(
     partition_by: Option<Vec<PyExpr>>,
     order_by: Option<Vec<PySortExpr>>,
-) -> PyResult<PyExpr> {
+) -> PyDataFusionResult<PyExpr> {
     let window_fn = functions_window::expr_fn::cume_dist();
 
     add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
@@ -840,7 +849,7 @@ pub fn ntile(
     arg: PyExpr,
     partition_by: Option<Vec<PyExpr>>,
     order_by: Option<Vec<PySortExpr>>,
-) -> PyResult<PyExpr> {
+) -> PyDataFusionResult<PyExpr> {
     let window_fn = functions_window::expr_fn::ntile(arg.into());
 
     add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
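[Editor's note — hedged sketch, not part of the patch. The window wrappers above all funnel into `add_builder_fns_to_window`; this is the plain-DataFusion builder chain that helper wraps. `ExprFunctionExt` and `functions_window::expr_fn::row_number` are datafusion APIs; the column names are invented.]

// Roughly what the wrappers build for
// row_number() OVER (PARTITION BY a ORDER BY b):
use datafusion::functions_window::expr_fn::row_number;
use datafusion::logical_expr::{Expr, ExprFunctionExt};
use datafusion::prelude::col;

fn row_number_over() -> datafusion::common::Result<Expr> {
    row_number()
        .partition_by(vec![col("a")])              // PARTITION BY a
        .order_by(vec![col("b").sort(true, true)]) // ORDER BY b ASC NULLS FIRST
        .build()                                   // finalize the window expr
}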
diff --git a/src/lib.rs b/src/lib.rs
index 1111d5d06..317c3a49a 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -48,6 +48,7 @@ pub mod expr;
 mod functions;
 pub mod physical_plan;
 mod pyarrow_filter_expression;
+pub mod pyarrow_util;
 mod record_batch;
 pub mod sql;
 pub mod store;
diff --git a/src/physical_plan.rs b/src/physical_plan.rs
index 9ef2f0ebb..295908dc7 100644
--- a/src/physical_plan.rs
+++ b/src/physical_plan.rs
@@ -22,7 +22,7 @@ use std::sync::Arc;
 
 use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyBytes};
 
-use crate::{context::PySessionContext, errors::DataFusionError};
+use crate::{context::PySessionContext, errors::PyDataFusionResult};
 
 #[pyclass(name = "ExecutionPlan", module = "datafusion", subclass)]
 #[derive(Debug, Clone)]
@@ -58,7 +58,7 @@ impl PyExecutionPlan {
         format!("{}", d.indent(false))
     }
 
-    pub fn to_proto<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
+    pub fn to_proto<'py>(&'py self, py: Python<'py>) -> PyDataFusionResult<Bound<'py, PyBytes>> {
         let codec = DefaultPhysicalExtensionCodec {};
         let proto = datafusion_proto::protobuf::PhysicalPlanNode::try_from_physical_plan(
             self.plan.clone(),
@@ -70,7 +70,10 @@ impl PyExecutionPlan {
     }
 
     #[staticmethod]
-    pub fn from_proto(ctx: PySessionContext, proto_msg: Bound<'_, PyBytes>) -> PyResult<Self> {
+    pub fn from_proto(
+        ctx: PySessionContext,
+        proto_msg: Bound<'_, PyBytes>,
+    ) -> PyDataFusionResult<Self> {
         let bytes: &[u8] = proto_msg.extract()?;
         let proto_plan =
             datafusion_proto::protobuf::PhysicalPlanNode::decode(bytes).map_err(|e| {
@@ -81,9 +84,7 @@ impl PyExecutionPlan {
             })?;
 
         let codec = DefaultPhysicalExtensionCodec {};
-        let plan = proto_plan
-            .try_into_physical_plan(&ctx.ctx, &ctx.ctx.runtime_env(), &codec)
-            .map_err(DataFusionError::from)?;
+        let plan = proto_plan.try_into_physical_plan(&ctx.ctx, &ctx.ctx.runtime_env(), &codec)?;
         Ok(Self::new(plan))
     }
diff --git a/src/pyarrow_filter_expression.rs b/src/pyarrow_filter_expression.rs
index 0f97ea442..314eebf4f 100644
--- a/src/pyarrow_filter_expression.rs
+++ b/src/pyarrow_filter_expression.rs
@@ -21,11 +21,11 @@ use pyo3::prelude::*;
 use std::convert::TryFrom;
 use std::result::Result;
 
-use arrow::pyarrow::ToPyArrow;
 use datafusion::common::{Column, ScalarValue};
 use datafusion::logical_expr::{expr::InList, Between, BinaryExpr, Expr, Operator};
 
-use crate::errors::DataFusionError;
+use crate::errors::{PyDataFusionError, PyDataFusionResult};
+use crate::pyarrow_util::scalar_to_pyarrow;
 
 #[derive(Debug)]
 #[repr(transparent)]
@@ -34,7 +34,7 @@ pub(crate) struct PyArrowFilterExpression(PyObject);
 fn operator_to_py<'py>(
     operator: &Operator,
     op: &Bound<'py, PyModule>,
-) -> Result<Bound<'py, PyAny>, DataFusionError> {
+) -> PyDataFusionResult<Bound<'py, PyAny>> {
     let py_op: Bound<'_, PyAny> = match operator {
         Operator::Eq => op.getattr("eq")?,
         Operator::NotEq => op.getattr("ne")?,
@@ -45,7 +45,7 @@ fn operator_to_py<'py>(
         Operator::And => op.getattr("and_")?,
         Operator::Or => op.getattr("or_")?,
         _ => {
-            return Err(DataFusionError::Common(format!(
+            return Err(PyDataFusionError::Common(format!(
                 "Unsupported operator {operator:?}"
             )))
         }
@@ -53,8 +53,8 @@ fn operator_to_py<'py>(
     Ok(py_op)
 }
 
-fn extract_scalar_list(exprs: &[Expr], py: Python) -> Result<Vec<PyObject>, DataFusionError> {
-    let ret: Result<Vec<PyObject>, DataFusionError> = exprs
+fn extract_scalar_list(exprs: &[Expr], py: Python) -> PyDataFusionResult<Vec<PyObject>> {
+    let ret = exprs
         .iter()
         .map(|expr| match expr {
             // TODO: should we also leverage `ScalarValue::to_pyarrow` here?
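[Editor's note — hedged sketch, not part of the patch. `operator_to_py` above maps DataFusion operators onto Python's `operator` module; this shows the Python-side call the conversion effectively makes for the filter `a > 1`. The pyo3 calls mirror the ones in the hunks; the `example` function name and the filter itself are invented.]

// Build the pyarrow.compute expression operator.gt(pc.field("a"), pc.scalar(1)).
use pyo3::prelude::*;

fn example() -> PyResult<()> {
    Python::with_gil(|py| {
        let pc = py.import_bound("pyarrow.compute")?;
        let op = py.import_bound("operator")?;
        let field = pc.getattr("field")?.call1(("a",))?;   // pc.field("a")
        let scalar = pc.getattr("scalar")?.call1((1i64,))?; // pc.scalar(1)
        // pyarrow Expressions support rich comparison, so operator.gt works:
        let expr = op.getattr("gt")?.call1((field, scalar))?;
        println!("{expr}"); // (a > 1)
        Ok(())
    })
}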
@@ -71,11 +71,11 @@ fn extract_scalar_list(exprs: &[Expr], py: Python) -> Result<Vec<PyObject>, DataFusionError>
                 ScalarValue::Float32(Some(f)) => Ok(f.into_py(py)),
                 ScalarValue::Float64(Some(f)) => Ok(f.into_py(py)),
                 ScalarValue::Utf8(Some(s)) => Ok(s.into_py(py)),
-                _ => Err(DataFusionError::Common(format!(
+                _ => Err(PyDataFusionError::Common(format!(
                     "PyArrow can't handle ScalarValue: {v:?}"
                 ))),
             },
-            _ => Err(DataFusionError::Common(format!(
+            _ => Err(PyDataFusionError::Common(format!(
                 "Only a list of Literals are supported got {expr:?}"
             ))),
         })
@@ -90,7 +90,7 @@ impl PyArrowFilterExpression {
     }
 }
 
 impl TryFrom<&Expr> for PyArrowFilterExpression {
-    type Error = DataFusionError;
+    type Error = PyDataFusionError;
 
     // Converts a Datafusion filter Expr into an expression string that can be evaluated by Python
     // Note that pyarrow.compute.{field,scalar} are put into Python globals() when evaluated
@@ -100,9 +100,9 @@ impl TryFrom<&Expr> for PyArrowFilterExpression {
         Python::with_gil(|py| {
             let pc = Python::import_bound(py, "pyarrow.compute")?;
             let op_module = Python::import_bound(py, "operator")?;
-            let pc_expr: Result<Bound<'_, PyAny>, DataFusionError> = match expr {
+            let pc_expr: PyDataFusionResult<Bound<'_, PyAny>> = match expr {
                 Expr::Column(Column { name, .. }) => Ok(pc.getattr("field")?.call1((name,))?),
-                Expr::Literal(scalar) => Ok(scalar.to_pyarrow(py)?.into_bound(py)),
+                Expr::Literal(scalar) => Ok(scalar_to_pyarrow(scalar, py)?.into_bound(py)),
                 Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
                     let operator = operator_to_py(op, &op_module)?;
                     let left = PyArrowFilterExpression::try_from(left.as_ref())?.0;
@@ -167,7 +167,7 @@ impl TryFrom<&Expr> for PyArrowFilterExpression {
                     Ok(if *negated { invert.call1((ret,))? } else { ret })
                 }
-                _ => Err(DataFusionError::Common(format!(
+                _ => Err(PyDataFusionError::Common(format!(
                     "Unsupported Datafusion expression {expr:?}"
                 ))),
             };
diff --git a/src/pyarrow_util.rs b/src/pyarrow_util.rs
new file mode 100644
index 000000000..2b31467f8
--- /dev/null
+++ b/src/pyarrow_util.rs
@@ -0,0 +1,61 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Conversions between PyArrow and DataFusion types
+
+use arrow::array::{Array, ArrayData};
+use arrow::pyarrow::{FromPyArrow, ToPyArrow};
+use datafusion::scalar::ScalarValue;
+use pyo3::types::{PyAnyMethods, PyList};
+use pyo3::{Bound, FromPyObject, PyAny, PyObject, PyResult, Python};
+
+use crate::common::data_type::PyScalarValue;
+use crate::errors::PyDataFusionError;
+
+impl FromPyArrow for PyScalarValue {
+    fn from_pyarrow_bound(value: &Bound<'_, PyAny>) -> PyResult<Self> {
+        let py = value.py();
+        let typ = value.getattr("type")?;
+        let val = value.call_method0("as_py")?;
+
+        // construct pyarrow array from the python value and pyarrow type
+        let factory = py.import_bound("pyarrow")?.getattr("array")?;
+        let args = PyList::new_bound(py, [val]);
+        let array = factory.call1((args, typ))?;
+
+        // convert the pyarrow array to rust array using C data interface
+        let array = arrow::array::make_array(ArrayData::from_pyarrow_bound(&array)?);
+        let scalar = ScalarValue::try_from_array(&array, 0).map_err(PyDataFusionError::from)?;
+
+        Ok(PyScalarValue(scalar))
+    }
+}
+
+impl<'source> FromPyObject<'source> for PyScalarValue {
+    fn extract_bound(value: &Bound<'source, PyAny>) -> PyResult<Self> {
+        Self::from_pyarrow_bound(value)
+    }
+}
+
+pub fn scalar_to_pyarrow(scalar: &ScalarValue, py: Python) -> PyResult<PyObject> {
+    let array = scalar.to_array().map_err(PyDataFusionError::from)?;
+    // convert to pyarrow array using C data interface
+    let pyarray = array.to_data().to_pyarrow(py)?;
+    let pyscalar = pyarray.call_method1(py, "__getitem__", (0,))?;
+
+    Ok(pyscalar)
+}
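[Editor's note — hedged sketch, not part of the patch. A round trip through the two conversions the new module defines: Rust `ScalarValue` to a pyarrow scalar via `scalar_to_pyarrow`, then back via `FromPyArrow`. It assumes crate-internal context (the `crate::` paths from this patch); the `round_trip` function name is invented.]

// Int32(42) -> pyarrow scalar -> PyScalarValue -> inner ScalarValue.
use arrow::pyarrow::FromPyArrow;
use datafusion::scalar::ScalarValue;
use pyo3::prelude::*;

use crate::common::data_type::PyScalarValue;
use crate::pyarrow_util::scalar_to_pyarrow;

fn round_trip() -> PyResult<()> {
    Python::with_gil(|py| {
        let v = ScalarValue::Int32(Some(42));
        let py_scalar = scalar_to_pyarrow(&v, py)?;             // pyarrow.Int32Scalar
        let back = PyScalarValue::from_pyarrow_bound(py_scalar.bind(py))?;
        assert_eq!(back.0, v);                                  // lossless for this type
        Ok(())
    })
}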
diff --git a/src/record_batch.rs b/src/record_batch.rs
index eacdb5867..ec61c263f 100644
--- a/src/record_batch.rs
+++ b/src/record_batch.rs
@@ -17,6 +17,7 @@
 
 use std::sync::Arc;
 
+use crate::errors::PyDataFusionError;
 use crate::utils::wait_for_future;
 use datafusion::arrow::pyarrow::ToPyArrow;
 use datafusion::arrow::record_batch::RecordBatch;
@@ -90,7 +91,7 @@ async fn next_stream(
     let mut stream = stream.lock().await;
     match stream.next().await {
         Some(Ok(batch)) => Ok(batch.into()),
-        Some(Err(e)) => Err(e.into()),
+        Some(Err(e)) => Err(PyDataFusionError::from(e))?,
         None => {
             // Depending on whether the iteration is sync or not, we raise either a
             // StopIteration or a StopAsyncIteration
diff --git a/src/sql/exceptions.rs b/src/sql/exceptions.rs
index c458402a0..cfb02274b 100644
--- a/src/sql/exceptions.rs
+++ b/src/sql/exceptions.rs
@@ -17,13 +17,7 @@
 
 use std::fmt::{Debug, Display};
 
-use pyo3::{create_exception, PyErr};
-
-// Identifies exceptions that occur while attempting to generate a `LogicalPlan` from a SQL string
-create_exception!(rust, ParsingException, pyo3::exceptions::PyException);
-
-// Identifies exceptions that occur during attempts to optimization an existing `LogicalPlan`
-create_exception!(rust, OptimizationException, pyo3::exceptions::PyException);
+use pyo3::PyErr;
 
 pub fn py_type_err(e: impl Debug + Display) -> PyErr {
     PyErr::new::<pyo3::exceptions::PyTypeError, _>(format!("{e}"))
@@ -33,10 +27,6 @@ pub fn py_runtime_err(e: impl Debug + Display) -> PyErr {
     PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("{e}"))
 }
 
-pub fn py_parsing_exp(e: impl Debug + Display) -> PyErr {
-    PyErr::new::<ParsingException, _>(format!("{e}"))
-}
-
-pub fn py_optimization_exp(e: impl Debug + Display) -> PyErr {
-    PyErr::new::<OptimizationException, _>(format!("{e}"))
+pub fn py_value_err(e: impl Debug + Display) -> PyErr {
+    PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("{e}"))
 }
diff --git a/src/sql/logical.rs b/src/sql/logical.rs
index a541889c7..1be33b75f 100644
--- a/src/sql/logical.rs
+++ b/src/sql/logical.rs
@@ -17,6 +17,7 @@
 
 use std::sync::Arc;
 
+use crate::errors::PyDataFusionResult;
 use crate::expr::aggregate::PyAggregate;
 use crate::expr::analyze::PyAnalyze;
 use crate::expr::distinct::PyDistinct;
@@ -34,7 +35,7 @@ use crate::expr::table_scan::PyTableScan;
 use crate::expr::unnest::PyUnnest;
 use crate::expr::window::PyWindowExpr;
 use crate::{context::PySessionContext, errors::py_unsupported_variant_err};
-use datafusion::{error::DataFusionError, logical_expr::LogicalPlan};
+use datafusion::logical_expr::LogicalPlan;
 use datafusion_proto::logical_plan::{AsLogicalPlan, DefaultLogicalExtensionCodec};
 use prost::Message;
 use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyBytes};
@@ -125,7 +126,7 @@ impl PyLogicalPlan {
         format!("{}", self.plan.display_graphviz())
     }
 
-    pub fn to_proto<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
+    pub fn to_proto<'py>(&'py self, py: Python<'py>) -> PyDataFusionResult<Bound<'py, PyBytes>> {
         let codec = DefaultLogicalExtensionCodec {};
         let proto =
             datafusion_proto::protobuf::LogicalPlanNode::try_from_logical_plan(&self.plan, &codec)?;
@@ -135,7 +136,10 @@ impl PyLogicalPlan {
     }
 
     #[staticmethod]
-    pub fn from_proto(ctx: PySessionContext, proto_msg: Bound<'_, PyBytes>) -> PyResult<Self> {
+    pub fn from_proto(
+        ctx: PySessionContext,
+        proto_msg: Bound<'_, PyBytes>,
+    ) -> PyDataFusionResult<Self> {
         let bytes: &[u8] = proto_msg.extract()?;
         let proto_plan =
             datafusion_proto::protobuf::LogicalPlanNode::decode(bytes).map_err(|e| {
@@ -146,9 +150,7 @@ impl PyLogicalPlan {
             })?;
 
         let codec = DefaultLogicalExtensionCodec {};
-        let plan = proto_plan
-            .try_into_logical_plan(&ctx.ctx, &codec)
-            .map_err(DataFusionError::from)?;
+        let plan = proto_plan.try_into_logical_plan(&ctx.ctx, &codec)?;
         Ok(Self::new(plan))
     }
 }
diff --git a/src/substrait.rs b/src/substrait.rs
index 16e8c9507..8dcf3e8a7 100644
--- a/src/substrait.rs
+++ b/src/substrait.rs
@@ -18,7 +18,7 @@
 use pyo3::{prelude::*, types::PyBytes};
 
 use crate::context::PySessionContext;
-use crate::errors::{py_datafusion_err, DataFusionError};
+use crate::errors::{py_datafusion_err, PyDataFusionError, PyDataFusionResult};
 use crate::sql::logical::PyLogicalPlan;
 use crate::utils::wait_for_future;
 
@@ -39,7 +39,7 @@ impl PyPlan {
         let mut proto_bytes = Vec::<u8>::new();
         self.plan
             .encode(&mut proto_bytes)
-            .map_err(DataFusionError::EncodeError)?;
+            .map_err(PyDataFusionError::EncodeError)?;
         Ok(PyBytes::new_bound(py, &proto_bytes).unbind().into())
     }
 }
@@ -66,41 +66,47 @@ pub struct PySubstraitSerializer;
 #[pymethods]
 impl PySubstraitSerializer {
     #[staticmethod]
-    pub fn serialize(sql: &str, ctx: PySessionContext, path: &str, py: Python) -> PyResult<()> {
-        wait_for_future(py, serializer::serialize(sql, &ctx.ctx, path))
-            .map_err(DataFusionError::from)?;
+    pub fn serialize(
+        sql: &str,
+        ctx: PySessionContext,
+        path: &str,
+        py: Python,
+    ) -> PyDataFusionResult<()> {
+        wait_for_future(py, serializer::serialize(sql, &ctx.ctx, path))?;
         Ok(())
     }
 
     #[staticmethod]
-    pub fn serialize_to_plan(sql: &str, ctx: PySessionContext, py: Python) -> PyResult<PyPlan> {
-        match PySubstraitSerializer::serialize_bytes(sql, ctx, py) {
-            Ok(proto_bytes) => {
-                let proto_bytes = proto_bytes.bind(py).downcast::<PyBytes>().unwrap();
-                PySubstraitSerializer::deserialize_bytes(proto_bytes.as_bytes().to_vec(), py)
-            }
-            Err(e) => Err(py_datafusion_err(e)),
-        }
+    pub fn serialize_to_plan(
+        sql: &str,
+        ctx: PySessionContext,
+        py: Python,
+    ) -> PyDataFusionResult<PyPlan> {
+        PySubstraitSerializer::serialize_bytes(sql, ctx, py).and_then(|proto_bytes| {
+            let proto_bytes = proto_bytes.bind(py).downcast::<PyBytes>().unwrap();
+            PySubstraitSerializer::deserialize_bytes(proto_bytes.as_bytes().to_vec(), py)
+        })
     }
 
     #[staticmethod]
-    pub fn serialize_bytes(sql: &str, ctx: PySessionContext, py: Python) -> PyResult<PyObject> {
-        let proto_bytes: Vec<u8> = wait_for_future(py, serializer::serialize_bytes(sql, &ctx.ctx))
-            .map_err(DataFusionError::from)?;
+    pub fn serialize_bytes(
+        sql: &str,
+        ctx: PySessionContext,
+        py: Python,
+    ) -> PyDataFusionResult<PyObject> {
+        let proto_bytes: Vec<u8> = wait_for_future(py, serializer::serialize_bytes(sql, &ctx.ctx))?;
         Ok(PyBytes::new_bound(py, &proto_bytes).unbind().into())
     }
 
     #[staticmethod]
-    pub fn deserialize(path: &str, py: Python) -> PyResult<PyPlan> {
-        let plan =
-            wait_for_future(py, serializer::deserialize(path)).map_err(DataFusionError::from)?;
+    pub fn deserialize(path: &str, py: Python) -> PyDataFusionResult<PyPlan> {
+        let plan = wait_for_future(py, serializer::deserialize(path))?;
         Ok(PyPlan { plan: *plan })
     }
 
     #[staticmethod]
-    pub fn deserialize_bytes(proto_bytes: Vec<u8>, py: Python) -> PyResult<PyPlan> {
-        let plan = wait_for_future(py, serializer::deserialize_bytes(proto_bytes))
-            .map_err(DataFusionError::from)?;
+    pub fn deserialize_bytes(proto_bytes: Vec<u8>, py: Python) -> PyDataFusionResult<PyPlan> {
+        let plan = wait_for_future(py, serializer::deserialize_bytes(proto_bytes))?;
         Ok(PyPlan { plan: *plan })
     }
 }
@@ -134,10 +140,10 @@ impl PySubstraitConsumer {
         ctx: &mut PySessionContext,
         plan: PyPlan,
         py: Python,
-    ) -> PyResult<PyLogicalPlan> {
+    ) -> PyDataFusionResult<PyLogicalPlan> {
         let session_state = ctx.ctx.state();
         let result = consumer::from_substrait_plan(&session_state, &plan.plan);
-        let logical_plan = wait_for_future(py, result).map_err(DataFusionError::from)?;
+        let logical_plan = wait_for_future(py, result)?;
         Ok(PyLogicalPlan::new(logical_plan))
     }
 }
diff --git a/src/udaf.rs b/src/udaf.rs
index a6aa59ac3..5f21533e0 100644
--- a/src/udaf.rs
+++ b/src/udaf.rs
@@ -28,6 +28,7 @@ use datafusion::logical_expr::{
     create_udaf, Accumulator, AccumulatorFactoryFunction, AggregateUDF,
 };
 
+use crate::common::data_type::PyScalarValue;
 use crate::expr::PyExpr;
 use crate::utils::parse_volatility;
 
@@ -44,13 +45,25 @@ impl RustAccumulator {
 impl Accumulator for RustAccumulator {
     fn state(&mut self) -> Result<Vec<ScalarValue>> {
-        Python::with_gil(|py| self.accum.bind(py).call_method0("state")?.extract())
-            .map_err(|e| DataFusionError::Execution(format!("{e}")))
+        Python::with_gil(|py| {
+            self.accum
+                .bind(py)
+                .call_method0("state")?
+                .extract::<Vec<PyScalarValue>>()
+        })
+        .map(|v| v.into_iter().map(|x| x.0).collect())
+        .map_err(|e| DataFusionError::Execution(format!("{e}")))
     }
 
     fn evaluate(&mut self) -> Result<ScalarValue> {
-        Python::with_gil(|py| self.accum.bind(py).call_method0("evaluate")?.extract())
-            .map_err(|e| DataFusionError::Execution(format!("{e}")))
+        Python::with_gil(|py| {
+            self.accum
+                .bind(py)
+                .call_method0("evaluate")?
+                .extract::<PyScalarValue>()
+        })
+        .map(|v| v.0)
+        .map_err(|e| DataFusionError::Execution(format!("{e}")))
     }
 
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
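[Editor's note — hedged sketch, not part of the patch. The udaf hunks above (and the udwf hunk below) switch from extracting `ScalarValue` directly to extracting `PyScalarValue` and unwrapping it; the common pattern, pulled out as a hypothetical helper (`extract_scalar` is an invented name, and `PyScalarValue: FromPyObject` comes from the new pyarrow_util.rs):]

// Turn a Python return value (a pyarrow scalar) into a DataFusion ScalarValue,
// mapping any Python-side failure to an execution error.
use datafusion::common::DataFusionError;
use datafusion::scalar::ScalarValue;
use pyo3::prelude::*;

use crate::common::data_type::PyScalarValue;

fn extract_scalar(obj: &Bound<'_, PyAny>) -> Result<ScalarValue, DataFusionError> {
    obj.extract::<PyScalarValue>()
        .map(|v| v.0) // unwrap the newtype
        .map_err(|e| DataFusionError::Execution(format!("{e}")))
}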
diff --git a/src/udwf.rs b/src/udwf.rs
index 689eb79e3..04a4a1640 100644
--- a/src/udwf.rs
+++ b/src/udwf.rs
@@ -26,6 +26,7 @@ use datafusion::scalar::ScalarValue;
 use pyo3::exceptions::PyValueError;
 use pyo3::prelude::*;
 
+use crate::common::data_type::PyScalarValue;
 use crate::expr::PyExpr;
 use crate::utils::parse_volatility;
 use datafusion::arrow::datatypes::DataType;
@@ -133,7 +134,8 @@ impl PartitionEvaluator for RustPartitionEvaluator {
             self.evaluator
                 .bind(py)
                 .call_method1("evaluate", py_args)
-                .and_then(|v| v.extract())
+                .and_then(|v| v.extract::<PyScalarValue>())
+                .map(|v| v.0)
                 .map_err(|e| DataFusionError::Execution(format!("{e}")))
         })
     }
diff --git a/src/utils.rs b/src/utils.rs
index 795589752..ed224b364 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::errors::DataFusionError;
+use crate::errors::{PyDataFusionError, PyDataFusionResult};
 use crate::TokioRuntime;
 use datafusion::logical_expr::Volatility;
 use pyo3::exceptions::PyValueError;
@@ -47,13 +47,13 @@ where
     py.allow_threads(|| runtime.block_on(f))
 }
 
-pub(crate) fn parse_volatility(value: &str) -> Result<Volatility, DataFusionError> {
+pub(crate) fn parse_volatility(value: &str) -> PyDataFusionResult<Volatility> {
     Ok(match value {
         "immutable" => Volatility::Immutable,
         "stable" => Volatility::Stable,
         "volatile" => Volatility::Volatile,
         value => {
-            return Err(DataFusionError::Common(format!(
+            return Err(PyDataFusionError::Common(format!(
                 "Unsupported volatility type: `{value}`, supported \
                 values are: immutable, stable and volatile."
            )))