diff --git a/Cargo.lock b/Cargo.lock index bf67256c3..2b62a69dd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1594,6 +1594,7 @@ dependencies = [ "log", "mimalloc", "object_store", + "parking_lot", "prost", "prost-types", "pyo3", diff --git a/Cargo.toml b/Cargo.toml index f1d1a0236..2c48bdd5f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,6 +51,7 @@ futures = "0.3" object_store = { version = "0.12.3", features = ["aws", "gcp", "azure", "http"] } url = "2" log = "0.4.27" +parking_lot = "0.12" [build-dependencies] prost-types = "0.13.1" # keep in line with `datafusion-substrait` diff --git a/docs/source/contributor-guide/ffi.rst b/docs/source/contributor-guide/ffi.rst index e201db71e..e8a0398b8 100644 --- a/docs/source/contributor-guide/ffi.rst +++ b/docs/source/contributor-guide/ffi.rst @@ -137,6 +137,67 @@ and you want to create a sharable FFI counterpart, you could write: let my_provider = MyTableProvider::default(); let ffi_provider = FFI_TableProvider::new(Arc::new(my_provider), false, None); +.. _ffi_pyclass_mutability: + +PyO3 class mutability guidelines +-------------------------------- + +PyO3 bindings should present immutable wrappers whenever a struct stores shared or +interior-mutable state. In practice this means that any ``#[pyclass]`` containing an +``Arc>`` or similar synchronized primitive must opt into ``#[pyclass(frozen)]`` +unless there is a compelling reason not to. + +The :mod:`datafusion` configuration helpers illustrate the preferred pattern. The +``PyConfig`` class in :file:`src/config.rs` stores an ``Arc>`` and is +explicitly frozen so callers interact with configuration state through provided methods +instead of mutating the container directly: + +.. code-block:: rust + + #[pyclass(name = "Config", module = "datafusion", subclass, frozen)] + #[derive(Clone)] + pub(crate) struct PyConfig { + config: Arc>, + } + +The same approach applies to execution contexts. ``PySessionContext`` in +:file:`src/context.rs` stays frozen even though it shares mutable state internally via +``SessionContext``. This ensures PyO3 tracks borrows correctly while Python-facing APIs +clone the inner ``SessionContext`` or return new wrappers instead of mutating the +existing instance in place: + +.. code-block:: rust + + #[pyclass(frozen, name = "SessionContext", module = "datafusion", subclass)] + #[derive(Clone)] + pub struct PySessionContext { + pub ctx: SessionContext, + } + +Occasionally a type must remain mutable—for example when PyO3 attribute setters need to +update fields directly. In these rare cases add an inline justification so reviewers and +future contributors understand why ``frozen`` is unsafe to enable. ``DataTypeMap`` in +:file:`src/common/data_type.rs` includes such a comment because PyO3 still needs to track +field updates: + +.. code-block:: rust + + // TODO: This looks like this needs pyo3 tracking so leaving unfrozen for now + #[derive(Debug, Clone)] + #[pyclass(name = "DataTypeMap", module = "datafusion.common", subclass)] + pub struct DataTypeMap { + #[pyo3(get, set)] + pub arrow_type: PyDataType, + #[pyo3(get, set)] + pub python_type: PythonType, + #[pyo3(get, set)] + pub sql_type: SqlType, + } + +When reviewers encounter a mutable ``#[pyclass]`` without a comment, they should request +an explanation or ask that ``frozen`` be added. Keeping these wrappers frozen by default +helps avoid subtle bugs stemming from PyO3's interior mutability tracking. + If you were interfacing with a library that provided the above ``FFI_TableProvider`` and you needed to turn it back into an ``TableProvider``, you can turn it into a ``ForeignTableProvider`` with implements the ``TableProvider`` trait. diff --git a/docs/source/contributor-guide/introduction.rst b/docs/source/contributor-guide/introduction.rst index 6cb05c62d..33c2b274c 100644 --- a/docs/source/contributor-guide/introduction.rst +++ b/docs/source/contributor-guide/introduction.rst @@ -26,6 +26,10 @@ We welcome and encourage contributions of all kinds, such as: In addition to submitting new PRs, we have a healthy tradition of community members reviewing each other’s PRs. Doing so is a great way to help the community as well as get more familiar with Rust and the relevant codebases. +Before opening a pull request that touches PyO3 bindings, please review the +:ref:`PyO3 class mutability guidelines ` so you can flag missing +``#[pyclass(frozen)]`` annotations during development and review. + How to develop -------------- diff --git a/pyproject.toml b/pyproject.toml index edecc4588..69d31ec9f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,6 +107,9 @@ convention = "google" [tool.ruff.lint.pycodestyle] max-doc-length = 88 +[tool.ruff.lint.flake8-boolean-trap] +extend-allowed-calls = ["lit", "datafusion.lit"] + # Disable docstring checking for these directories [tool.ruff.lint.per-file-ignores] "python/tests/*" = [ diff --git a/python/tests/test_concurrency.py b/python/tests/test_concurrency.py new file mode 100644 index 000000000..f790f9473 --- /dev/null +++ b/python/tests/test_concurrency.py @@ -0,0 +1,126 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from concurrent.futures import ThreadPoolExecutor + +import pyarrow as pa +from datafusion import Config, SessionContext, col, lit +from datafusion import functions as f +from datafusion.common import SqlSchema + + +def _run_in_threads(fn, count: int = 8) -> None: + with ThreadPoolExecutor(max_workers=count) as executor: + futures = [executor.submit(fn, i) for i in range(count)] + for future in futures: + # Propagate any exception raised in the worker thread. + future.result() + + +def test_concurrent_access_to_shared_structures() -> None: + """Exercise SqlSchema, Config, and DataFrame concurrently.""" + + schema = SqlSchema("concurrency") + config = Config() + ctx = SessionContext() + + batch = pa.record_batch([pa.array([1, 2, 3], type=pa.int32())], names=["value"]) + df = ctx.create_dataframe([[batch]]) + + config_key = "datafusion.execution.batch_size" + expected_rows = batch.num_rows + + def worker(index: int) -> None: + schema.name = f"concurrency-{index}" + assert schema.name.startswith("concurrency-") + # Exercise getters that use internal locks. + assert isinstance(schema.tables, list) + assert isinstance(schema.views, list) + assert isinstance(schema.functions, list) + + config.set(config_key, str(1024 + index)) + assert config.get(config_key) is not None + # Access the full config map to stress lock usage. + assert config_key in config.get_all() + + batches = df.collect() + assert sum(batch.num_rows for batch in batches) == expected_rows + + _run_in_threads(worker, count=12) + + +def test_config_set_during_get_all() -> None: + """Ensure config writes proceed while another thread reads all entries.""" + + config = Config() + key = "datafusion.execution.batch_size" + + def reader() -> None: + for _ in range(200): + # get_all should not hold the lock while converting to Python objects + config.get_all() + + def writer() -> None: + for index in range(200): + config.set(key, str(1024 + index)) + + with ThreadPoolExecutor(max_workers=2) as executor: + reader_future = executor.submit(reader) + writer_future = executor.submit(writer) + reader_future.result(timeout=10) + writer_future.result(timeout=10) + + assert config.get(key) is not None + + +def test_case_builder_reuse_from_multiple_threads() -> None: + """Ensure the case builder can be safely reused across threads.""" + + ctx = SessionContext() + values = pa.array([0, 1, 2, 3, 4], type=pa.int32()) + df = ctx.create_dataframe([[pa.record_batch([values], names=["value"])]]) + + base_builder = f.case(col("value")) + + def add_case(i: int) -> None: + nonlocal base_builder + base_builder = base_builder.when(lit(i), lit(f"value-{i}")) + + _run_in_threads(add_case, count=8) + + with ThreadPoolExecutor(max_workers=2) as executor: + otherwise_future = executor.submit(base_builder.otherwise, lit("default")) + case_expr = otherwise_future.result() + + result = df.select(case_expr.alias("label")).collect() + assert sum(batch.num_rows for batch in result) == len(values) + + predicate_builder = f.when(col("value") == lit(0), lit("zero")) + + def add_predicate(i: int) -> None: + predicate_builder.when(col("value") == lit(i + 1), lit(f"value-{i + 1}")) + + _run_in_threads(add_predicate, count=4) + + with ThreadPoolExecutor(max_workers=2) as executor: + end_future = executor.submit(predicate_builder.end) + predicate_expr = end_future.result() + + result = df.select(predicate_expr.alias("label")).collect() + assert sum(batch.num_rows for batch in result) == len(values) diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index 810d419cf..7847826ac 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -16,6 +16,7 @@ # under the License. import re +from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timezone import pyarrow as pa @@ -200,6 +201,98 @@ def traverse_logical_plan(plan): assert not variant.negated() +def test_case_builder_error_preserves_builder_state(): + case_builder = functions.when(lit(True), lit(1)) + + with pytest.raises(Exception) as exc_info: + _ = case_builder.otherwise(lit("bad")) + + err_msg = str(exc_info.value) + assert "multiple data types" in err_msg + assert "CaseBuilder has already been consumed" not in err_msg + + _ = case_builder.end() + + err_msg = str(exc_info.value) + assert "multiple data types" in err_msg + assert "CaseBuilder has already been consumed" not in err_msg + + +def test_case_builder_success_preserves_builder_state(): + ctx = SessionContext() + df = ctx.from_pydict({"flag": [False]}, name="tbl") + + case_builder = functions.when(col("flag"), lit("true")) + + expr_default_one = case_builder.otherwise(lit("default-1")).alias("result") + result_one = df.select(expr_default_one).collect() + assert result_one[0].column(0).to_pylist() == ["default-1"] + + expr_default_two = case_builder.otherwise(lit("default-2")).alias("result") + result_two = df.select(expr_default_two).collect() + assert result_two[0].column(0).to_pylist() == ["default-2"] + + expr_end_one = case_builder.end().alias("result") + end_one = df.select(expr_end_one).collect() + assert end_one[0].column(0).to_pylist() == [None] + + +def test_case_builder_when_handles_are_independent(): + ctx = SessionContext() + df = ctx.from_pydict( + { + "flag": [True, False, False, False], + "value": [1, 15, 25, 5], + }, + name="tbl", + ) + + base_builder = functions.when(col("flag"), lit("flag-true")) + + first_builder = base_builder.when(col("value") > lit(10), lit("gt10")) + second_builder = base_builder.when(col("value") > lit(20), lit("gt20")) + + first_builder = first_builder.when(lit(True), lit("final-one")) + + expr_first = first_builder.otherwise(lit("fallback-one")).alias("first") + expr_second = second_builder.otherwise(lit("fallback-two")).alias("second") + + result = df.select(expr_first, expr_second).collect()[0] + + assert result.column(0).to_pylist() == [ + "flag-true", + "gt10", + "gt10", + "final-one", + ] + assert result.column(1).to_pylist() == [ + "flag-true", + "fallback-two", + "gt20", + "fallback-two", + ] + + +def test_case_builder_when_thread_safe(): + case_builder = functions.when(lit(True), lit(1)) + + def build_expr(value: int) -> bool: + builder = case_builder.when(lit(True), lit(value)) + builder.otherwise(lit(value)) + return True + + with ThreadPoolExecutor(max_workers=8) as executor: + futures = [executor.submit(build_expr, idx) for idx in range(16)] + results = [future.result() for future in futures] + + assert all(results) + + # Ensure the shared builder remains usable after concurrent `when` calls. + follow_up_builder = case_builder.when(lit(True), lit(42)) + assert isinstance(follow_up_builder, type(case_builder)) + follow_up_builder.otherwise(lit(7)) + + def test_expr_getitem() -> None: ctx = SessionContext() data = { diff --git a/python/tests/test_pyclass_frozen.py b/python/tests/test_pyclass_frozen.py new file mode 100644 index 000000000..189ea8dec --- /dev/null +++ b/python/tests/test_pyclass_frozen.py @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Ensure exposed pyclasses default to frozen.""" + +from __future__ import annotations + +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Iterator + +PYCLASS_RE = re.compile( + r"#\[\s*pyclass\s*(?:\((?P.*?)\))?\s*\]", + re.DOTALL, +) +ARG_STRING_RE = re.compile( + r"(?P[A-Za-z_][A-Za-z0-9_]*)\s*=\s*\"(?P[^\"]+)\"", +) +STRUCT_NAME_RE = re.compile( + r"\b(?:pub\s+)?(?:struct|enum)\s+" + r"(?P[A-Za-z_][A-Za-z0-9_]*)", +) + + +@dataclass +class PyClass: + module: str + name: str + frozen: bool + source: Path + + +def iter_pyclasses(root: Path) -> Iterator[PyClass]: + for path in root.rglob("*.rs"): + text = path.read_text(encoding="utf8") + for match in PYCLASS_RE.finditer(text): + args = match.group("args") or "" + frozen = re.search(r"\bfrozen\b", args) is not None + + module = None + name = None + for arg_match in ARG_STRING_RE.finditer(args): + key = arg_match.group("key") + value = arg_match.group("value") + if key == "module": + module = value + elif key == "name": + name = value + + remainder = text[match.end() :] + struct_match = STRUCT_NAME_RE.search(remainder) + struct_name = struct_match.group("name") if struct_match else None + + yield PyClass( + module=module or "datafusion", + name=name or struct_name or "", + frozen=frozen, + source=path, + ) + + +def test_pyclasses_are_frozen() -> None: + allowlist = { + # NOTE: Any new exceptions must include a justification comment + # in the Rust source and, ideally, a follow-up issue to remove + # the exemption. + ("datafusion.common", "SqlTable"), + ("datafusion.common", "SqlView"), + ("datafusion.common", "DataTypeMap"), + ("datafusion.expr", "TryCast"), + ("datafusion.expr", "WriteOp"), + } + + unfrozen = [ + pyclass + for pyclass in iter_pyclasses(Path("src")) + if not pyclass.frozen and (pyclass.module, pyclass.name) not in allowlist + ] + + if unfrozen: + msg = ( + "Found pyclasses missing `frozen`; add them to the allowlist only " + "with a justification comment and follow-up plan:\n" + ) + msg += "\n".join( + (f"- {pyclass.module}.{pyclass.name} (defined in {pyclass.source})") + for pyclass in unfrozen + ) + assert not unfrozen, msg diff --git a/src/catalog.rs b/src/catalog.rs index 17d4ec3b8..b5fa3da72 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -36,19 +36,19 @@ use std::any::Any; use std::collections::HashSet; use std::sync::Arc; -#[pyclass(name = "RawCatalog", module = "datafusion.catalog", subclass)] +#[pyclass(frozen, name = "RawCatalog", module = "datafusion.catalog", subclass)] #[derive(Clone)] pub struct PyCatalog { pub catalog: Arc, } -#[pyclass(name = "RawSchema", module = "datafusion.catalog", subclass)] +#[pyclass(frozen, name = "RawSchema", module = "datafusion.catalog", subclass)] #[derive(Clone)] pub struct PySchema { pub schema: Arc, } -#[pyclass(name = "RawTable", module = "datafusion.catalog", subclass)] +#[pyclass(frozen, name = "RawTable", module = "datafusion.catalog", subclass)] #[derive(Clone)] pub struct PyTable { pub table: Arc, diff --git a/src/common/data_type.rs b/src/common/data_type.rs index 4d7743397..3cbe31332 100644 --- a/src/common/data_type.rs +++ b/src/common/data_type.rs @@ -37,7 +37,7 @@ impl From for ScalarValue { } #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[pyclass(eq, eq_int, name = "RexType", module = "datafusion.common")] +#[pyclass(frozen, eq, eq_int, name = "RexType", module = "datafusion.common")] pub enum RexType { Alias, Literal, @@ -56,6 +56,7 @@ pub enum RexType { /// and manageable location. Therefore this structure exists /// to map those types and provide a simple place for developers /// to map types from one system to another. +// TODO: This looks like this needs pyo3 tracking so leaving unfrozen for now #[derive(Debug, Clone)] #[pyclass(name = "DataTypeMap", module = "datafusion.common", subclass)] pub struct DataTypeMap { @@ -577,7 +578,7 @@ impl DataTypeMap { /// Since `DataType` exists in another package we cannot make that happen here so we wrap /// `DataType` as `PyDataType` This exists solely to satisfy those constraints. #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[pyclass(name = "DataType", module = "datafusion.common")] +#[pyclass(frozen, name = "DataType", module = "datafusion.common")] pub struct PyDataType { pub data_type: DataType, } @@ -635,7 +636,7 @@ impl From for PyDataType { /// Represents the possible Python types that can be mapped to the SQL types #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[pyclass(eq, eq_int, name = "PythonType", module = "datafusion.common")] +#[pyclass(frozen, eq, eq_int, name = "PythonType", module = "datafusion.common")] pub enum PythonType { Array, Bool, @@ -655,7 +656,7 @@ pub enum PythonType { #[allow(non_camel_case_types)] #[allow(clippy::upper_case_acronyms)] #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[pyclass(eq, eq_int, name = "SqlType", module = "datafusion.common")] +#[pyclass(frozen, eq, eq_int, name = "SqlType", module = "datafusion.common")] pub enum SqlType { ANY, ARRAY, @@ -713,7 +714,13 @@ pub enum SqlType { #[allow(non_camel_case_types)] #[allow(clippy::upper_case_acronyms)] #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[pyclass(eq, eq_int, name = "NullTreatment", module = "datafusion.common")] +#[pyclass( + frozen, + eq, + eq_int, + name = "NullTreatment", + module = "datafusion.common" +)] pub enum NullTreatment { IGNORE_NULLS, RESPECT_NULLS, diff --git a/src/common/df_schema.rs b/src/common/df_schema.rs index 4e1d84060..eb62469cf 100644 --- a/src/common/df_schema.rs +++ b/src/common/df_schema.rs @@ -21,7 +21,7 @@ use datafusion::common::DFSchema; use pyo3::prelude::*; #[derive(Debug, Clone)] -#[pyclass(name = "DFSchema", module = "datafusion.common", subclass)] +#[pyclass(frozen, name = "DFSchema", module = "datafusion.common", subclass)] pub struct PyDFSchema { schema: Arc, } diff --git a/src/common/function.rs b/src/common/function.rs index a8d752f16..bc6f23160 100644 --- a/src/common/function.rs +++ b/src/common/function.rs @@ -22,7 +22,7 @@ use pyo3::prelude::*; use super::data_type::PyDataType; -#[pyclass(name = "SqlFunction", module = "datafusion.common", subclass)] +#[pyclass(frozen, name = "SqlFunction", module = "datafusion.common", subclass)] #[derive(Debug, Clone)] pub struct SqlFunction { pub name: String, diff --git a/src/common/schema.rs b/src/common/schema.rs index 752c39bde..14ab630d3 100644 --- a/src/common/schema.rs +++ b/src/common/schema.rs @@ -33,17 +33,15 @@ use crate::sql::logical::PyLogicalPlan; use super::{data_type::DataTypeMap, function::SqlFunction}; -#[pyclass(name = "SqlSchema", module = "datafusion.common", subclass)] +use parking_lot::RwLock; + +#[pyclass(name = "SqlSchema", module = "datafusion.common", subclass, frozen)] #[derive(Debug, Clone)] pub struct SqlSchema { - #[pyo3(get, set)] - pub name: String, - #[pyo3(get, set)] - pub tables: Vec, - #[pyo3(get, set)] - pub views: Vec, - #[pyo3(get, set)] - pub functions: Vec, + name: Arc>, + tables: Arc>>, + views: Arc>>, + functions: Arc>>, } #[pyclass(name = "SqlTable", module = "datafusion.common", subclass)] @@ -104,28 +102,70 @@ impl SqlSchema { #[new] pub fn new(schema_name: &str) -> Self { Self { - name: schema_name.to_owned(), - tables: Vec::new(), - views: Vec::new(), - functions: Vec::new(), + name: Arc::new(RwLock::new(schema_name.to_owned())), + tables: Arc::new(RwLock::new(Vec::new())), + views: Arc::new(RwLock::new(Vec::new())), + functions: Arc::new(RwLock::new(Vec::new())), } } + #[getter] + fn name(&self) -> PyResult { + Ok(self.name.read().clone()) + } + + #[setter] + fn set_name(&self, value: String) -> PyResult<()> { + *self.name.write() = value; + Ok(()) + } + + #[getter] + fn tables(&self) -> PyResult> { + Ok(self.tables.read().clone()) + } + + #[setter] + fn set_tables(&self, tables: Vec) -> PyResult<()> { + *self.tables.write() = tables; + Ok(()) + } + + #[getter] + fn views(&self) -> PyResult> { + Ok(self.views.read().clone()) + } + + #[setter] + fn set_views(&self, views: Vec) -> PyResult<()> { + *self.views.write() = views; + Ok(()) + } + + #[getter] + fn functions(&self) -> PyResult> { + Ok(self.functions.read().clone()) + } + + #[setter] + fn set_functions(&self, functions: Vec) -> PyResult<()> { + *self.functions.write() = functions; + Ok(()) + } + pub fn table_by_name(&self, table_name: &str) -> Option { - for tbl in &self.tables { - if tbl.name.eq(table_name) { - return Some(tbl.clone()); - } - } - None + let tables = self.tables.read(); + tables.iter().find(|tbl| tbl.name.eq(table_name)).cloned() } - pub fn add_table(&mut self, table: SqlTable) { - self.tables.push(table); + pub fn add_table(&self, table: SqlTable) { + let mut tables = self.tables.write(); + tables.push(table); } - pub fn drop_table(&mut self, table_name: String) { - self.tables.retain(|x| !x.name.eq(&table_name)); + pub fn drop_table(&self, table_name: String) { + let mut tables = self.tables.write(); + tables.retain(|x| !x.name.eq(&table_name)); } } @@ -208,7 +248,7 @@ fn is_supported_push_down_expr(_expr: &Expr) -> bool { true } -#[pyclass(name = "SqlStatistics", module = "datafusion.common", subclass)] +#[pyclass(frozen, name = "SqlStatistics", module = "datafusion.common", subclass)] #[derive(Debug, Clone)] pub struct SqlStatistics { row_count: f64, @@ -227,7 +267,7 @@ impl SqlStatistics { } } -#[pyclass(name = "Constraints", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Constraints", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyConstraints { pub constraints: Constraints, @@ -252,7 +292,7 @@ impl Display for PyConstraints { } #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[pyclass(eq, eq_int, name = "TableType", module = "datafusion.common")] +#[pyclass(frozen, eq, eq_int, name = "TableType", module = "datafusion.common")] pub enum PyTableType { Base, View, @@ -279,7 +319,7 @@ impl From for PyTableType { } } -#[pyclass(name = "TableSource", module = "datafusion.common", subclass)] +#[pyclass(frozen, name = "TableSource", module = "datafusion.common", subclass)] #[derive(Clone)] pub struct PyTableSource { pub table_source: Arc, diff --git a/src/config.rs b/src/config.rs index 20f22196c..1726e5d9b 100644 --- a/src/config.rs +++ b/src/config.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + use pyo3::prelude::*; use pyo3::types::*; @@ -22,11 +24,11 @@ use datafusion::config::ConfigOptions; use crate::errors::PyDataFusionResult; use crate::utils::py_obj_to_scalar_value; - -#[pyclass(name = "Config", module = "datafusion", subclass)] +use parking_lot::RwLock; +#[pyclass(name = "Config", module = "datafusion", subclass, frozen)] #[derive(Clone)] pub(crate) struct PyConfig { - config: ConfigOptions, + config: Arc>, } #[pymethods] @@ -34,7 +36,7 @@ impl PyConfig { #[new] fn py_new() -> Self { Self { - config: ConfigOptions::new(), + config: Arc::new(RwLock::new(ConfigOptions::new())), } } @@ -42,41 +44,54 @@ impl PyConfig { #[staticmethod] pub fn from_env() -> PyDataFusionResult { Ok(Self { - config: ConfigOptions::from_env()?, + config: Arc::new(RwLock::new(ConfigOptions::from_env()?)), }) } /// Get a configuration option - pub fn get<'py>(&mut self, key: &str, py: Python<'py>) -> PyResult> { - let options = self.config.to_owned(); - for entry in options.entries() { - if entry.key == key { - return Ok(entry.value.into_pyobject(py)?); - } + pub fn get<'py>(&self, key: &str, py: Python<'py>) -> PyResult> { + let value: Option> = { + let options = self.config.read(); + options + .entries() + .into_iter() + .find_map(|entry| (entry.key == key).then_some(entry.value.clone())) + }; + + match value { + Some(value) => Ok(value.into_pyobject(py)?), + None => Ok(None::.into_pyobject(py)?), } - Ok(None::.into_pyobject(py)?) } /// Set a configuration option - pub fn set(&mut self, key: &str, value: PyObject, py: Python) -> PyDataFusionResult<()> { + pub fn set(&self, key: &str, value: PyObject, py: Python) -> PyDataFusionResult<()> { let scalar_value = py_obj_to_scalar_value(py, value)?; - self.config.set(key, scalar_value.to_string().as_str())?; + let mut options = self.config.write(); + options.set(key, scalar_value.to_string().as_str())?; Ok(()) } /// Get all configuration options - pub fn get_all(&mut self, py: Python) -> PyResult { + pub fn get_all(&self, py: Python) -> PyResult { + let entries: Vec<(String, Option)> = { + let options = self.config.read(); + options + .entries() + .into_iter() + .map(|entry| (entry.key.clone(), entry.value.clone())) + .collect() + }; + let dict = PyDict::new(py); - let options = self.config.to_owned(); - for entry in options.entries() { - dict.set_item(entry.key, entry.value.clone().into_pyobject(py)?)?; + for (key, value) in entries { + dict.set_item(key, value.into_pyobject(py)?)?; } Ok(dict.into()) } - fn __repr__(&mut self, py: Python) -> PyResult { - let dict = self.get_all(py); - match dict { + fn __repr__(&self, py: Python) -> PyResult { + match self.get_all(py) { Ok(result) => Ok(format!("Config({result})")), Err(err) => Ok(format!("Error: {:?}", err.to_string())), } diff --git a/src/context.rs b/src/context.rs index 0ccb03261..e3f978ee1 100644 --- a/src/context.rs +++ b/src/context.rs @@ -77,7 +77,7 @@ use pyo3::IntoPyObjectExt; use tokio::task::JoinHandle; /// Configuration options for a SessionContext -#[pyclass(name = "SessionConfig", module = "datafusion", subclass)] +#[pyclass(frozen, name = "SessionConfig", module = "datafusion", subclass)] #[derive(Clone, Default)] pub struct PySessionConfig { pub config: SessionConfig, @@ -170,7 +170,7 @@ impl PySessionConfig { } /// Runtime options for a SessionContext -#[pyclass(name = "RuntimeEnvBuilder", module = "datafusion", subclass)] +#[pyclass(frozen, name = "RuntimeEnvBuilder", module = "datafusion", subclass)] #[derive(Clone)] pub struct PyRuntimeEnvBuilder { pub builder: RuntimeEnvBuilder, @@ -257,7 +257,7 @@ impl PyRuntimeEnvBuilder { } /// `PySQLOptions` allows you to specify options to the sql execution. -#[pyclass(name = "SQLOptions", module = "datafusion", subclass)] +#[pyclass(frozen, name = "SQLOptions", module = "datafusion", subclass)] #[derive(Clone)] pub struct PySQLOptions { pub options: SQLOptions, diff --git a/src/dataframe.rs b/src/dataframe.rs index 5882acf76..555a8500d 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -58,10 +58,17 @@ use crate::{ expr::{sort_expr::PySortExpr, PyExpr}, }; +use parking_lot::Mutex; + +// Type aliases to simplify very complex types used in this file and +// avoid compiler complaints about deeply nested types in struct fields. +type CachedBatches = Option<(Vec, bool)>; +type SharedCachedBatches = Arc>; + // https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116 // - we have not decided on the table_provider approach yet // this is an interim implementation -#[pyclass(name = "TableProvider", module = "datafusion")] +#[pyclass(frozen, name = "TableProvider", module = "datafusion")] pub struct PyTableProvider { provider: Arc, } @@ -188,7 +195,7 @@ fn build_formatter_config_from_python(formatter: &Bound<'_, PyAny>) -> PyResult< } /// Python mapping of `ParquetOptions` (includes just the writer-related options). -#[pyclass(name = "ParquetWriterOptions", module = "datafusion", subclass)] +#[pyclass(frozen, name = "ParquetWriterOptions", module = "datafusion", subclass)] #[derive(Clone, Default)] pub struct PyParquetWriterOptions { options: ParquetOptions, @@ -249,7 +256,7 @@ impl PyParquetWriterOptions { } /// Python mapping of `ParquetColumnOptions`. -#[pyclass(name = "ParquetColumnOptions", module = "datafusion", subclass)] +#[pyclass(frozen, name = "ParquetColumnOptions", module = "datafusion", subclass)] #[derive(Clone, Default)] pub struct PyParquetColumnOptions { options: ParquetColumnOptions, @@ -284,13 +291,13 @@ impl PyParquetColumnOptions { /// A PyDataFrame is a representation of a logical plan and an API to compose statements. /// Use it to build a plan and `.collect()` to execute the plan and collect the result. /// The actual execution of a plan runs natively on Rust and Arrow on a multi-threaded environment. -#[pyclass(name = "DataFrame", module = "datafusion", subclass)] +#[pyclass(name = "DataFrame", module = "datafusion", subclass, frozen)] #[derive(Clone)] pub struct PyDataFrame { df: Arc, // In IPython environment cache batches between __repr__ and _repr_html_ calls. - batches: Option<(Vec, bool)>, + batches: SharedCachedBatches, } impl PyDataFrame { @@ -298,16 +305,24 @@ impl PyDataFrame { pub fn new(df: DataFrame) -> Self { Self { df: Arc::new(df), - batches: None, + batches: Arc::new(Mutex::new(None)), } } - fn prepare_repr_string(&mut self, py: Python, as_html: bool) -> PyDataFusionResult { + fn prepare_repr_string(&self, py: Python, as_html: bool) -> PyDataFusionResult { // Get the Python formatter and config let PythonFormatter { formatter, config } = get_python_formatter_with_config(py)?; - let should_cache = *is_ipython_env(py) && self.batches.is_none(); - let (batches, has_more) = match self.batches.take() { + let is_ipython = *is_ipython_env(py); + + let (cached_batches, should_cache) = { + let mut cache = self.batches.lock(); + let should_cache = is_ipython && cache.is_none(); + let batches = cache.take(); + (batches, should_cache) + }; + + let (batches, has_more) = match cached_batches { Some(b) => b, None => wait_for_future( py, @@ -346,7 +361,8 @@ impl PyDataFrame { let html_str: String = html_result.extract()?; if should_cache { - self.batches = Some((batches, has_more)); + let mut cache = self.batches.lock(); + *cache = Some((batches.clone(), has_more)); } Ok(html_str) @@ -376,7 +392,7 @@ impl PyDataFrame { } } - fn __repr__(&mut self, py: Python) -> PyDataFusionResult { + fn __repr__(&self, py: Python) -> PyDataFusionResult { self.prepare_repr_string(py, false) } @@ -411,7 +427,7 @@ impl PyDataFrame { Ok(format!("DataFrame()\n{batches_as_displ}{additional_str}")) } - fn _repr_html_(&mut self, py: Python) -> PyDataFusionResult { + fn _repr_html_(&self, py: Python) -> PyDataFusionResult { self.prepare_repr_string(py, true) } @@ -874,7 +890,7 @@ impl PyDataFrame { #[pyo3(signature = (requested_schema=None))] fn __arrow_c_stream__<'py>( - &'py mut self, + &'py self, py: Python<'py>, requested_schema: Option>, ) -> PyDataFusionResult> { diff --git a/src/expr.rs b/src/expr.rs index e2c53025c..c9eddaa2d 100644 --- a/src/expr.rs +++ b/src/expr.rs @@ -115,7 +115,7 @@ pub mod window; use sort_expr::{to_sort_expressions, PySortExpr}; /// A PyExpr that can be used on a DataFrame -#[pyclass(name = "RawExpr", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "RawExpr", module = "datafusion.expr", subclass)] #[derive(Debug, Clone)] pub struct PyExpr { pub expr: Expr, @@ -637,7 +637,7 @@ impl PyExpr { } } -#[pyclass(name = "ExprFuncBuilder", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "ExprFuncBuilder", module = "datafusion.expr", subclass)] #[derive(Debug, Clone)] pub struct PyExprFuncBuilder { pub builder: ExprFuncBuilder, diff --git a/src/expr/aggregate.rs b/src/expr/aggregate.rs index fd4393271..4af7c755a 100644 --- a/src/expr/aggregate.rs +++ b/src/expr/aggregate.rs @@ -28,7 +28,7 @@ use crate::errors::py_type_err; use crate::expr::PyExpr; use crate::sql::logical::PyLogicalPlan; -#[pyclass(name = "Aggregate", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Aggregate", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyAggregate { aggregate: Aggregate, diff --git a/src/expr/aggregate_expr.rs b/src/expr/aggregate_expr.rs index 7c5d3d31f..72ba0638f 100644 --- a/src/expr/aggregate_expr.rs +++ b/src/expr/aggregate_expr.rs @@ -20,7 +20,12 @@ use datafusion::logical_expr::expr::AggregateFunction; use pyo3::prelude::*; use std::fmt::{Display, Formatter}; -#[pyclass(name = "AggregateFunction", module = "datafusion.expr", subclass)] +#[pyclass( + frozen, + name = "AggregateFunction", + module = "datafusion.expr", + subclass +)] #[derive(Clone)] pub struct PyAggregateFunction { aggr: AggregateFunction, diff --git a/src/expr/alias.rs b/src/expr/alias.rs index 40746f200..588c00fdf 100644 --- a/src/expr/alias.rs +++ b/src/expr/alias.rs @@ -21,7 +21,7 @@ use std::fmt::{self, Display, Formatter}; use datafusion::logical_expr::expr::Alias; -#[pyclass(name = "Alias", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Alias", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyAlias { alias: Alias, diff --git a/src/expr/analyze.rs b/src/expr/analyze.rs index e8081e95b..c7caeebc8 100644 --- a/src/expr/analyze.rs +++ b/src/expr/analyze.rs @@ -23,7 +23,7 @@ use super::logical_node::LogicalNode; use crate::common::df_schema::PyDFSchema; use crate::sql::logical::PyLogicalPlan; -#[pyclass(name = "Analyze", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Analyze", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyAnalyze { analyze: Analyze, diff --git a/src/expr/between.rs b/src/expr/between.rs index 817f1baae..1f61599a3 100644 --- a/src/expr/between.rs +++ b/src/expr/between.rs @@ -20,7 +20,7 @@ use datafusion::logical_expr::expr::Between; use pyo3::prelude::*; use std::fmt::{self, Display, Formatter}; -#[pyclass(name = "Between", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Between", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyBetween { between: Between, diff --git a/src/expr/binary_expr.rs b/src/expr/binary_expr.rs index 740299211..94379583c 100644 --- a/src/expr/binary_expr.rs +++ b/src/expr/binary_expr.rs @@ -19,7 +19,7 @@ use crate::expr::PyExpr; use datafusion::logical_expr::BinaryExpr; use pyo3::prelude::*; -#[pyclass(name = "BinaryExpr", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "BinaryExpr", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyBinaryExpr { expr: BinaryExpr, diff --git a/src/expr/bool_expr.rs b/src/expr/bool_expr.rs index e67e25d74..0d2b051e6 100644 --- a/src/expr/bool_expr.rs +++ b/src/expr/bool_expr.rs @@ -21,7 +21,7 @@ use std::fmt::{self, Display, Formatter}; use super::PyExpr; -#[pyclass(name = "Not", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Not", module = "datafusion.expr", subclass)] #[derive(Clone, Debug)] pub struct PyNot { expr: Expr, @@ -51,7 +51,7 @@ impl PyNot { } } -#[pyclass(name = "IsNotNull", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "IsNotNull", module = "datafusion.expr", subclass)] #[derive(Clone, Debug)] pub struct PyIsNotNull { expr: Expr, @@ -81,7 +81,7 @@ impl PyIsNotNull { } } -#[pyclass(name = "IsNull", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "IsNull", module = "datafusion.expr", subclass)] #[derive(Clone, Debug)] pub struct PyIsNull { expr: Expr, @@ -111,7 +111,7 @@ impl PyIsNull { } } -#[pyclass(name = "IsTrue", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "IsTrue", module = "datafusion.expr", subclass)] #[derive(Clone, Debug)] pub struct PyIsTrue { expr: Expr, @@ -141,7 +141,7 @@ impl PyIsTrue { } } -#[pyclass(name = "IsFalse", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "IsFalse", module = "datafusion.expr", subclass)] #[derive(Clone, Debug)] pub struct PyIsFalse { expr: Expr, @@ -171,7 +171,7 @@ impl PyIsFalse { } } -#[pyclass(name = "IsUnknown", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "IsUnknown", module = "datafusion.expr", subclass)] #[derive(Clone, Debug)] pub struct PyIsUnknown { expr: Expr, @@ -201,7 +201,7 @@ impl PyIsUnknown { } } -#[pyclass(name = "IsNotTrue", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "IsNotTrue", module = "datafusion.expr", subclass)] #[derive(Clone, Debug)] pub struct PyIsNotTrue { expr: Expr, @@ -231,7 +231,7 @@ impl PyIsNotTrue { } } -#[pyclass(name = "IsNotFalse", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "IsNotFalse", module = "datafusion.expr", subclass)] #[derive(Clone, Debug)] pub struct PyIsNotFalse { expr: Expr, @@ -261,7 +261,7 @@ impl PyIsNotFalse { } } -#[pyclass(name = "IsNotUnknown", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "IsNotUnknown", module = "datafusion.expr", subclass)] #[derive(Clone, Debug)] pub struct PyIsNotUnknown { expr: Expr, @@ -291,7 +291,7 @@ impl PyIsNotUnknown { } } -#[pyclass(name = "Negative", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Negative", module = "datafusion.expr", subclass)] #[derive(Clone, Debug)] pub struct PyNegative { expr: Expr, diff --git a/src/expr/case.rs b/src/expr/case.rs index 92e28ba56..1a7369826 100644 --- a/src/expr/case.rs +++ b/src/expr/case.rs @@ -19,7 +19,7 @@ use crate::expr::PyExpr; use datafusion::logical_expr::Case; use pyo3::prelude::*; -#[pyclass(name = "Case", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Case", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyCase { case: Case, diff --git a/src/expr/cast.rs b/src/expr/cast.rs index b8faea634..03e2b8476 100644 --- a/src/expr/cast.rs +++ b/src/expr/cast.rs @@ -19,7 +19,7 @@ use crate::{common::data_type::PyDataType, expr::PyExpr}; use datafusion::logical_expr::{Cast, TryCast}; use pyo3::prelude::*; -#[pyclass(name = "Cast", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Cast", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyCast { cast: Cast, diff --git a/src/expr/column.rs b/src/expr/column.rs index 50f316f1c..300079481 100644 --- a/src/expr/column.rs +++ b/src/expr/column.rs @@ -18,7 +18,7 @@ use datafusion::common::Column; use pyo3::prelude::*; -#[pyclass(name = "Column", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Column", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyColumn { pub col: Column, diff --git a/src/expr/conditional_expr.rs b/src/expr/conditional_expr.rs index fe3af2e25..21f538ba0 100644 --- a/src/expr/conditional_expr.rs +++ b/src/expr/conditional_expr.rs @@ -17,38 +17,60 @@ use crate::{errors::PyDataFusionResult, expr::PyExpr}; use datafusion::logical_expr::conditional_expressions::CaseBuilder; +use datafusion::prelude::Expr; use pyo3::prelude::*; -#[pyclass(name = "CaseBuilder", module = "datafusion.expr", subclass)] +// TODO(tsaucer) replace this all with CaseBuilder after it implements Clone +#[derive(Clone, Debug)] +#[pyclass(name = "CaseBuilder", module = "datafusion.expr", subclass, frozen)] pub struct PyCaseBuilder { - pub case_builder: CaseBuilder, -} - -impl From for CaseBuilder { - fn from(case_builder: PyCaseBuilder) -> Self { - case_builder.case_builder - } -} - -impl From for PyCaseBuilder { - fn from(case_builder: CaseBuilder) -> PyCaseBuilder { - PyCaseBuilder { case_builder } - } + expr: Option, + when: Vec, + then: Vec, } #[pymethods] impl PyCaseBuilder { - fn when(&mut self, when: PyExpr, then: PyExpr) -> PyCaseBuilder { - PyCaseBuilder { - case_builder: self.case_builder.when(when.expr, then.expr), + #[new] + pub fn new(expr: Option) -> Self { + Self { + expr: expr.map(Into::into), + when: vec![], + then: vec![], } } - fn otherwise(&mut self, else_expr: PyExpr) -> PyDataFusionResult { - Ok(self.case_builder.otherwise(else_expr.expr)?.clone().into()) + pub fn when(&self, when: PyExpr, then: PyExpr) -> PyCaseBuilder { + let mut case_builder = self.clone(); + case_builder.when.push(when.into()); + case_builder.then.push(then.into()); + + case_builder } - fn end(&mut self) -> PyDataFusionResult { - Ok(self.case_builder.end()?.clone().into()) + fn otherwise(&self, else_expr: PyExpr) -> PyDataFusionResult { + let case_builder = CaseBuilder::new( + self.expr.clone().map(Box::new), + self.when.clone(), + self.then.clone(), + Some(Box::new(else_expr.into())), + ); + + let expr = case_builder.end()?; + + Ok(expr.into()) + } + + fn end(&self) -> PyDataFusionResult { + let case_builder = CaseBuilder::new( + self.expr.clone().map(Box::new), + self.when.clone(), + self.then.clone(), + None, + ); + + let expr = case_builder.end()?; + + Ok(expr.into()) } } diff --git a/src/expr/copy_to.rs b/src/expr/copy_to.rs index c2f7c61d4..422ab77f4 100644 --- a/src/expr/copy_to.rs +++ b/src/expr/copy_to.rs @@ -28,7 +28,7 @@ use crate::sql::logical::PyLogicalPlan; use super::logical_node::LogicalNode; -#[pyclass(name = "CopyTo", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "CopyTo", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyCopyTo { copy: CopyTo, @@ -114,7 +114,7 @@ impl PyCopyTo { } } -#[pyclass(name = "FileType", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "FileType", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyFileType { file_type: Arc, diff --git a/src/expr/create_catalog.rs b/src/expr/create_catalog.rs index d2d2ee8f6..361387894 100644 --- a/src/expr/create_catalog.rs +++ b/src/expr/create_catalog.rs @@ -27,7 +27,7 @@ use crate::{common::df_schema::PyDFSchema, sql::logical::PyLogicalPlan}; use super::logical_node::LogicalNode; -#[pyclass(name = "CreateCatalog", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "CreateCatalog", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyCreateCatalog { create: CreateCatalog, diff --git a/src/expr/create_catalog_schema.rs b/src/expr/create_catalog_schema.rs index e794962f5..cb3be2d30 100644 --- a/src/expr/create_catalog_schema.rs +++ b/src/expr/create_catalog_schema.rs @@ -27,7 +27,12 @@ use crate::{common::df_schema::PyDFSchema, sql::logical::PyLogicalPlan}; use super::logical_node::LogicalNode; -#[pyclass(name = "CreateCatalogSchema", module = "datafusion.expr", subclass)] +#[pyclass( + frozen, + name = "CreateCatalogSchema", + module = "datafusion.expr", + subclass +)] #[derive(Clone)] pub struct PyCreateCatalogSchema { create: CreateCatalogSchema, diff --git a/src/expr/create_external_table.rs b/src/expr/create_external_table.rs index 3e35af006..920d0d613 100644 --- a/src/expr/create_external_table.rs +++ b/src/expr/create_external_table.rs @@ -29,7 +29,12 @@ use crate::common::df_schema::PyDFSchema; use super::{logical_node::LogicalNode, sort_expr::PySortExpr}; -#[pyclass(name = "CreateExternalTable", module = "datafusion.expr", subclass)] +#[pyclass( + frozen, + name = "CreateExternalTable", + module = "datafusion.expr", + subclass +)] #[derive(Clone)] pub struct PyCreateExternalTable { create: CreateExternalTable, diff --git a/src/expr/create_function.rs b/src/expr/create_function.rs index c02ceebb1..1b663b466 100644 --- a/src/expr/create_function.rs +++ b/src/expr/create_function.rs @@ -30,7 +30,7 @@ use super::PyExpr; use crate::common::{data_type::PyDataType, df_schema::PyDFSchema}; use crate::sql::logical::PyLogicalPlan; -#[pyclass(name = "CreateFunction", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "CreateFunction", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyCreateFunction { create: CreateFunction, @@ -54,21 +54,31 @@ impl Display for PyCreateFunction { } } -#[pyclass(name = "OperateFunctionArg", module = "datafusion.expr", subclass)] +#[pyclass( + frozen, + name = "OperateFunctionArg", + module = "datafusion.expr", + subclass +)] #[derive(Clone)] pub struct PyOperateFunctionArg { arg: OperateFunctionArg, } #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[pyclass(eq, eq_int, name = "Volatility", module = "datafusion.expr")] +#[pyclass(frozen, eq, eq_int, name = "Volatility", module = "datafusion.expr")] pub enum PyVolatility { Immutable, Stable, Volatile, } -#[pyclass(name = "CreateFunctionBody", module = "datafusion.expr", subclass)] +#[pyclass( + frozen, + name = "CreateFunctionBody", + module = "datafusion.expr", + subclass +)] #[derive(Clone)] pub struct PyCreateFunctionBody { body: CreateFunctionBody, diff --git a/src/expr/create_index.rs b/src/expr/create_index.rs index 0f4b5011a..7b68df305 100644 --- a/src/expr/create_index.rs +++ b/src/expr/create_index.rs @@ -27,7 +27,7 @@ use crate::{common::df_schema::PyDFSchema, sql::logical::PyLogicalPlan}; use super::{logical_node::LogicalNode, sort_expr::PySortExpr}; -#[pyclass(name = "CreateIndex", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "CreateIndex", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyCreateIndex { create: CreateIndex, diff --git a/src/expr/create_memory_table.rs b/src/expr/create_memory_table.rs index 37f4d3420..15aaa810b 100644 --- a/src/expr/create_memory_table.rs +++ b/src/expr/create_memory_table.rs @@ -24,7 +24,12 @@ use crate::sql::logical::PyLogicalPlan; use super::logical_node::LogicalNode; -#[pyclass(name = "CreateMemoryTable", module = "datafusion.expr", subclass)] +#[pyclass( + frozen, + name = "CreateMemoryTable", + module = "datafusion.expr", + subclass +)] #[derive(Clone)] pub struct PyCreateMemoryTable { create: CreateMemoryTable, diff --git a/src/expr/create_view.rs b/src/expr/create_view.rs index 718e404d0..49b3b6199 100644 --- a/src/expr/create_view.rs +++ b/src/expr/create_view.rs @@ -24,7 +24,7 @@ use crate::{errors::py_type_err, sql::logical::PyLogicalPlan}; use super::logical_node::LogicalNode; -#[pyclass(name = "CreateView", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "CreateView", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyCreateView { create: CreateView, diff --git a/src/expr/describe_table.rs b/src/expr/describe_table.rs index 6c48f3c77..315026fef 100644 --- a/src/expr/describe_table.rs +++ b/src/expr/describe_table.rs @@ -28,7 +28,7 @@ use crate::{common::df_schema::PyDFSchema, sql::logical::PyLogicalPlan}; use super::logical_node::LogicalNode; -#[pyclass(name = "DescribeTable", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "DescribeTable", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyDescribeTable { describe: DescribeTable, diff --git a/src/expr/distinct.rs b/src/expr/distinct.rs index 889e7099d..5770b849d 100644 --- a/src/expr/distinct.rs +++ b/src/expr/distinct.rs @@ -24,7 +24,7 @@ use crate::sql::logical::PyLogicalPlan; use super::logical_node::LogicalNode; -#[pyclass(name = "Distinct", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Distinct", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyDistinct { distinct: Distinct, diff --git a/src/expr/dml.rs b/src/expr/dml.rs index 251e336cc..4437a9de9 100644 --- a/src/expr/dml.rs +++ b/src/expr/dml.rs @@ -24,7 +24,7 @@ use crate::{common::df_schema::PyDFSchema, sql::logical::PyLogicalPlan}; use super::logical_node::LogicalNode; -#[pyclass(name = "DmlStatement", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "DmlStatement", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyDmlStatement { dml: DmlStatement, diff --git a/src/expr/drop_catalog_schema.rs b/src/expr/drop_catalog_schema.rs index b4a4c521c..7008bcd24 100644 --- a/src/expr/drop_catalog_schema.rs +++ b/src/expr/drop_catalog_schema.rs @@ -28,7 +28,12 @@ use crate::common::df_schema::PyDFSchema; use super::logical_node::LogicalNode; use crate::sql::logical::PyLogicalPlan; -#[pyclass(name = "DropCatalogSchema", module = "datafusion.expr", subclass)] +#[pyclass( + frozen, + name = "DropCatalogSchema", + module = "datafusion.expr", + subclass +)] #[derive(Clone)] pub struct PyDropCatalogSchema { drop: DropCatalogSchema, diff --git a/src/expr/drop_function.rs b/src/expr/drop_function.rs index fca9eb94b..42ad3e1fe 100644 --- a/src/expr/drop_function.rs +++ b/src/expr/drop_function.rs @@ -27,7 +27,7 @@ use super::logical_node::LogicalNode; use crate::common::df_schema::PyDFSchema; use crate::sql::logical::PyLogicalPlan; -#[pyclass(name = "DropFunction", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "DropFunction", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyDropFunction { drop: DropFunction, diff --git a/src/expr/drop_table.rs b/src/expr/drop_table.rs index 3f442539a..6ff4f01c4 100644 --- a/src/expr/drop_table.rs +++ b/src/expr/drop_table.rs @@ -24,7 +24,7 @@ use crate::sql::logical::PyLogicalPlan; use super::logical_node::LogicalNode; -#[pyclass(name = "DropTable", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "DropTable", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyDropTable { drop: DropTable, diff --git a/src/expr/drop_view.rs b/src/expr/drop_view.rs index 6196c8bb5..b2aff4e9b 100644 --- a/src/expr/drop_view.rs +++ b/src/expr/drop_view.rs @@ -28,7 +28,7 @@ use crate::common::df_schema::PyDFSchema; use super::logical_node::LogicalNode; use crate::sql::logical::PyLogicalPlan; -#[pyclass(name = "DropView", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "DropView", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyDropView { drop: DropView, diff --git a/src/expr/empty_relation.rs b/src/expr/empty_relation.rs index 758213423..797a8c02d 100644 --- a/src/expr/empty_relation.rs +++ b/src/expr/empty_relation.rs @@ -22,7 +22,7 @@ use std::fmt::{self, Display, Formatter}; use super::logical_node::LogicalNode; -#[pyclass(name = "EmptyRelation", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "EmptyRelation", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyEmptyRelation { empty: EmptyRelation, diff --git a/src/expr/exists.rs b/src/expr/exists.rs index 693357836..392bfcb9e 100644 --- a/src/expr/exists.rs +++ b/src/expr/exists.rs @@ -20,7 +20,7 @@ use pyo3::prelude::*; use super::subquery::PySubquery; -#[pyclass(name = "Exists", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Exists", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyExists { exists: Exists, diff --git a/src/expr/explain.rs b/src/expr/explain.rs index fc02fe2b5..71b7b2c13 100644 --- a/src/expr/explain.rs +++ b/src/expr/explain.rs @@ -24,7 +24,7 @@ use crate::{common::df_schema::PyDFSchema, errors::py_type_err, sql::logical::Py use super::logical_node::LogicalNode; -#[pyclass(name = "Explain", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Explain", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyExplain { explain: Explain, diff --git a/src/expr/extension.rs b/src/expr/extension.rs index 1e3fbb199..7d913ff8c 100644 --- a/src/expr/extension.rs +++ b/src/expr/extension.rs @@ -22,7 +22,7 @@ use crate::sql::logical::PyLogicalPlan; use super::logical_node::LogicalNode; -#[pyclass(name = "Extension", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Extension", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyExtension { pub node: Extension, diff --git a/src/expr/filter.rs b/src/expr/filter.rs index 4fcb600cd..76338d139 100644 --- a/src/expr/filter.rs +++ b/src/expr/filter.rs @@ -24,7 +24,7 @@ use crate::expr::logical_node::LogicalNode; use crate::expr::PyExpr; use crate::sql::logical::PyLogicalPlan; -#[pyclass(name = "Filter", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Filter", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyFilter { filter: Filter, diff --git a/src/expr/grouping_set.rs b/src/expr/grouping_set.rs index 63a1c0b50..107dd9370 100644 --- a/src/expr/grouping_set.rs +++ b/src/expr/grouping_set.rs @@ -18,7 +18,7 @@ use datafusion::logical_expr::GroupingSet; use pyo3::prelude::*; -#[pyclass(name = "GroupingSet", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "GroupingSet", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyGroupingSet { grouping_set: GroupingSet, diff --git a/src/expr/in_list.rs b/src/expr/in_list.rs index 5dfd8d8eb..e2e6d7832 100644 --- a/src/expr/in_list.rs +++ b/src/expr/in_list.rs @@ -19,7 +19,7 @@ use crate::expr::PyExpr; use datafusion::logical_expr::expr::InList; use pyo3::prelude::*; -#[pyclass(name = "InList", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "InList", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyInList { in_list: InList, diff --git a/src/expr/in_subquery.rs b/src/expr/in_subquery.rs index 306b68a6e..6d4a38e49 100644 --- a/src/expr/in_subquery.rs +++ b/src/expr/in_subquery.rs @@ -20,7 +20,7 @@ use pyo3::prelude::*; use super::{subquery::PySubquery, PyExpr}; -#[pyclass(name = "InSubquery", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "InSubquery", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyInSubquery { in_subquery: InSubquery, diff --git a/src/expr/indexed_field.rs b/src/expr/indexed_field.rs index a22dc6b27..1dfa0ed2f 100644 --- a/src/expr/indexed_field.rs +++ b/src/expr/indexed_field.rs @@ -22,7 +22,7 @@ use std::fmt::{Display, Formatter}; use super::literal::PyLiteral; -#[pyclass(name = "GetIndexedField", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "GetIndexedField", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyGetIndexedField { indexed_field: GetIndexedField, diff --git a/src/expr/join.rs b/src/expr/join.rs index 7b7e0d9dd..3fde874d5 100644 --- a/src/expr/join.rs +++ b/src/expr/join.rs @@ -25,7 +25,7 @@ use crate::expr::{logical_node::LogicalNode, PyExpr}; use crate::sql::logical::PyLogicalPlan; #[derive(Debug, Clone, PartialEq, Eq, Hash)] -#[pyclass(name = "JoinType", module = "datafusion.expr")] +#[pyclass(frozen, name = "JoinType", module = "datafusion.expr")] pub struct PyJoinType { join_type: JoinType, } @@ -60,7 +60,7 @@ impl Display for PyJoinType { } #[derive(Debug, Clone, Copy)] -#[pyclass(name = "JoinConstraint", module = "datafusion.expr")] +#[pyclass(frozen, name = "JoinConstraint", module = "datafusion.expr")] pub struct PyJoinConstraint { join_constraint: JoinConstraint, } @@ -87,7 +87,7 @@ impl PyJoinConstraint { } } -#[pyclass(name = "Join", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Join", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyJoin { join: Join, diff --git a/src/expr/like.rs b/src/expr/like.rs index f180f5d4c..0a36dcd92 100644 --- a/src/expr/like.rs +++ b/src/expr/like.rs @@ -21,7 +21,7 @@ use std::fmt::{self, Display, Formatter}; use crate::expr::PyExpr; -#[pyclass(name = "Like", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Like", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyLike { like: Like, @@ -79,7 +79,7 @@ impl PyLike { } } -#[pyclass(name = "ILike", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "ILike", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyILike { like: Like, @@ -137,7 +137,7 @@ impl PyILike { } } -#[pyclass(name = "SimilarTo", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "SimilarTo", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PySimilarTo { like: Like, diff --git a/src/expr/limit.rs b/src/expr/limit.rs index 92552814e..cf6971fb3 100644 --- a/src/expr/limit.rs +++ b/src/expr/limit.rs @@ -23,7 +23,7 @@ use crate::common::df_schema::PyDFSchema; use crate::expr::logical_node::LogicalNode; use crate::sql::logical::PyLogicalPlan; -#[pyclass(name = "Limit", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Limit", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyLimit { limit: Limit, diff --git a/src/expr/literal.rs b/src/expr/literal.rs index 561242c9c..8a589b55a 100644 --- a/src/expr/literal.rs +++ b/src/expr/literal.rs @@ -19,7 +19,7 @@ use crate::errors::PyDataFusionError; use datafusion::{common::ScalarValue, logical_expr::expr::FieldMetadata}; use pyo3::{prelude::*, IntoPyObjectExt}; -#[pyclass(name = "Literal", module = "datafusion.expr", subclass)] +#[pyclass(name = "Literal", module = "datafusion.expr", subclass, frozen)] #[derive(Clone)] pub struct PyLiteral { pub value: ScalarValue, @@ -71,7 +71,7 @@ impl PyLiteral { extract_scalar_value!(self, Float64) } - pub fn value_decimal128(&mut self) -> PyResult<(Option, u8, i8)> { + pub fn value_decimal128(&self) -> PyResult<(Option, u8, i8)> { match &self.value { ScalarValue::Decimal128(value, precision, scale) => Ok((*value, *precision, *scale)), other => Err(unexpected_literal_value(other)), @@ -122,7 +122,7 @@ impl PyLiteral { extract_scalar_value!(self, Time64Nanosecond) } - pub fn value_timestamp(&mut self) -> PyResult<(Option, Option)> { + pub fn value_timestamp(&self) -> PyResult<(Option, Option)> { match &self.value { ScalarValue::TimestampNanosecond(iv, tz) | ScalarValue::TimestampMicrosecond(iv, tz) diff --git a/src/expr/placeholder.rs b/src/expr/placeholder.rs index 4ac2c47e3..268263d41 100644 --- a/src/expr/placeholder.rs +++ b/src/expr/placeholder.rs @@ -20,7 +20,7 @@ use pyo3::prelude::*; use crate::common::data_type::PyDataType; -#[pyclass(name = "Placeholder", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Placeholder", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyPlaceholder { placeholder: Placeholder, diff --git a/src/expr/projection.rs b/src/expr/projection.rs index b5a9ef34a..b2d5db79b 100644 --- a/src/expr/projection.rs +++ b/src/expr/projection.rs @@ -25,7 +25,7 @@ use crate::expr::logical_node::LogicalNode; use crate::expr::PyExpr; use crate::sql::logical::PyLogicalPlan; -#[pyclass(name = "Projection", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Projection", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyProjection { pub projection: Projection, diff --git a/src/expr/recursive_query.rs b/src/expr/recursive_query.rs index 2517b7417..fe047315e 100644 --- a/src/expr/recursive_query.rs +++ b/src/expr/recursive_query.rs @@ -24,7 +24,7 @@ use crate::sql::logical::PyLogicalPlan; use super::logical_node::LogicalNode; -#[pyclass(name = "RecursiveQuery", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "RecursiveQuery", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyRecursiveQuery { query: RecursiveQuery, diff --git a/src/expr/repartition.rs b/src/expr/repartition.rs index 48b5e7041..ee6d1dc45 100644 --- a/src/expr/repartition.rs +++ b/src/expr/repartition.rs @@ -24,13 +24,13 @@ use crate::{errors::py_type_err, sql::logical::PyLogicalPlan}; use super::{logical_node::LogicalNode, PyExpr}; -#[pyclass(name = "Repartition", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Repartition", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyRepartition { repartition: Repartition, } -#[pyclass(name = "Partitioning", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Partitioning", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyPartitioning { partitioning: Partitioning, diff --git a/src/expr/scalar_subquery.rs b/src/expr/scalar_subquery.rs index 9d35f28a9..e58d66e19 100644 --- a/src/expr/scalar_subquery.rs +++ b/src/expr/scalar_subquery.rs @@ -20,7 +20,7 @@ use pyo3::prelude::*; use super::subquery::PySubquery; -#[pyclass(name = "ScalarSubquery", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "ScalarSubquery", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyScalarSubquery { subquery: Subquery, diff --git a/src/expr/scalar_variable.rs b/src/expr/scalar_variable.rs index 7b50ba241..f3c128a4c 100644 --- a/src/expr/scalar_variable.rs +++ b/src/expr/scalar_variable.rs @@ -20,7 +20,7 @@ use pyo3::prelude::*; use crate::common::data_type::PyDataType; -#[pyclass(name = "ScalarVariable", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "ScalarVariable", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyScalarVariable { data_type: DataType, diff --git a/src/expr/signature.rs b/src/expr/signature.rs index e85763555..e2c23dce9 100644 --- a/src/expr/signature.rs +++ b/src/expr/signature.rs @@ -19,7 +19,7 @@ use datafusion::logical_expr::{TypeSignature, Volatility}; use pyo3::prelude::*; #[allow(dead_code)] -#[pyclass(name = "Signature", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Signature", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PySignature { type_signature: TypeSignature, diff --git a/src/expr/sort.rs b/src/expr/sort.rs index 79a8aee50..d5ea07fdd 100644 --- a/src/expr/sort.rs +++ b/src/expr/sort.rs @@ -25,7 +25,7 @@ use crate::expr::logical_node::LogicalNode; use crate::expr::sort_expr::PySortExpr; use crate::sql::logical::PyLogicalPlan; -#[pyclass(name = "Sort", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Sort", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PySort { sort: Sort, diff --git a/src/expr/sort_expr.rs b/src/expr/sort_expr.rs index e2df6b963..3f279027e 100644 --- a/src/expr/sort_expr.rs +++ b/src/expr/sort_expr.rs @@ -20,7 +20,7 @@ use datafusion::logical_expr::SortExpr; use pyo3::prelude::*; use std::fmt::{self, Display, Formatter}; -#[pyclass(name = "SortExpr", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "SortExpr", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PySortExpr { pub(crate) sort: SortExpr, diff --git a/src/expr/statement.rs b/src/expr/statement.rs index 83774cda1..1ea4f9f7f 100644 --- a/src/expr/statement.rs +++ b/src/expr/statement.rs @@ -25,7 +25,12 @@ use crate::{common::data_type::PyDataType, sql::logical::PyLogicalPlan}; use super::{logical_node::LogicalNode, PyExpr}; -#[pyclass(name = "TransactionStart", module = "datafusion.expr", subclass)] +#[pyclass( + frozen, + name = "TransactionStart", + module = "datafusion.expr", + subclass +)] #[derive(Clone)] pub struct PyTransactionStart { transaction_start: TransactionStart, @@ -56,7 +61,13 @@ impl LogicalNode for PyTransactionStart { } #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[pyclass(eq, eq_int, name = "TransactionAccessMode", module = "datafusion.expr")] +#[pyclass( + frozen, + eq, + eq_int, + name = "TransactionAccessMode", + module = "datafusion.expr" +)] pub enum PyTransactionAccessMode { ReadOnly, ReadWrite, @@ -84,6 +95,7 @@ impl TryFrom for TransactionAccessMode { #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] #[pyclass( + frozen, eq, eq_int, name = "TransactionIsolationLevel", @@ -161,7 +173,7 @@ impl PyTransactionStart { } } -#[pyclass(name = "TransactionEnd", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "TransactionEnd", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyTransactionEnd { transaction_end: TransactionEnd, @@ -192,7 +204,13 @@ impl LogicalNode for PyTransactionEnd { } #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[pyclass(eq, eq_int, name = "TransactionConclusion", module = "datafusion.expr")] +#[pyclass( + frozen, + eq, + eq_int, + name = "TransactionConclusion", + module = "datafusion.expr" +)] pub enum PyTransactionConclusion { Commit, Rollback, @@ -236,7 +254,7 @@ impl PyTransactionEnd { } } -#[pyclass(name = "SetVariable", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "SetVariable", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PySetVariable { set_variable: SetVariable, @@ -284,7 +302,7 @@ impl PySetVariable { } } -#[pyclass(name = "Prepare", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Prepare", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyPrepare { prepare: Prepare, @@ -352,7 +370,7 @@ impl PyPrepare { } } -#[pyclass(name = "Execute", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Execute", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyExecute { execute: Execute, @@ -409,7 +427,7 @@ impl PyExecute { } } -#[pyclass(name = "Deallocate", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Deallocate", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyDeallocate { deallocate: Deallocate, diff --git a/src/expr/subquery.rs b/src/expr/subquery.rs index 77f56f9a9..785cf7d1a 100644 --- a/src/expr/subquery.rs +++ b/src/expr/subquery.rs @@ -24,7 +24,7 @@ use crate::sql::logical::PyLogicalPlan; use super::logical_node::LogicalNode; -#[pyclass(name = "Subquery", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Subquery", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PySubquery { subquery: Subquery, diff --git a/src/expr/subquery_alias.rs b/src/expr/subquery_alias.rs index 3302e7f23..ab1229bfe 100644 --- a/src/expr/subquery_alias.rs +++ b/src/expr/subquery_alias.rs @@ -24,7 +24,7 @@ use crate::{common::df_schema::PyDFSchema, sql::logical::PyLogicalPlan}; use super::logical_node::LogicalNode; -#[pyclass(name = "SubqueryAlias", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "SubqueryAlias", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PySubqueryAlias { subquery_alias: SubqueryAlias, diff --git a/src/expr/table_scan.rs b/src/expr/table_scan.rs index 329964687..34a140df3 100644 --- a/src/expr/table_scan.rs +++ b/src/expr/table_scan.rs @@ -24,7 +24,7 @@ use crate::expr::logical_node::LogicalNode; use crate::sql::logical::PyLogicalPlan; use crate::{common::df_schema::PyDFSchema, expr::PyExpr}; -#[pyclass(name = "TableScan", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "TableScan", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyTableScan { table_scan: TableScan, diff --git a/src/expr/union.rs b/src/expr/union.rs index e0b221398..b7b589650 100644 --- a/src/expr/union.rs +++ b/src/expr/union.rs @@ -23,7 +23,7 @@ use crate::common::df_schema::PyDFSchema; use crate::expr::logical_node::LogicalNode; use crate::sql::logical::PyLogicalPlan; -#[pyclass(name = "Union", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Union", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyUnion { union_: Union, diff --git a/src/expr/unnest.rs b/src/expr/unnest.rs index c8833347f..7ed7919b1 100644 --- a/src/expr/unnest.rs +++ b/src/expr/unnest.rs @@ -23,7 +23,7 @@ use crate::common::df_schema::PyDFSchema; use crate::expr::logical_node::LogicalNode; use crate::sql::logical::PyLogicalPlan; -#[pyclass(name = "Unnest", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Unnest", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyUnnest { unnest_: Unnest, diff --git a/src/expr/unnest_expr.rs b/src/expr/unnest_expr.rs index 634186ed8..2cdf46a59 100644 --- a/src/expr/unnest_expr.rs +++ b/src/expr/unnest_expr.rs @@ -21,7 +21,7 @@ use std::fmt::{self, Display, Formatter}; use super::PyExpr; -#[pyclass(name = "UnnestExpr", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "UnnestExpr", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyUnnestExpr { unnest: Unnest, diff --git a/src/expr/values.rs b/src/expr/values.rs index fb2692230..63d94ce00 100644 --- a/src/expr/values.rs +++ b/src/expr/values.rs @@ -25,7 +25,7 @@ use crate::{common::df_schema::PyDFSchema, sql::logical::PyLogicalPlan}; use super::{logical_node::LogicalNode, PyExpr}; -#[pyclass(name = "Values", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "Values", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyValues { values: Values, diff --git a/src/expr/window.rs b/src/expr/window.rs index 77ecb71aa..2723007ec 100644 --- a/src/expr/window.rs +++ b/src/expr/window.rs @@ -30,13 +30,13 @@ use std::fmt::{self, Display, Formatter}; use super::py_expr_list; -#[pyclass(name = "WindowExpr", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "WindowExpr", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyWindowExpr { window: Window, } -#[pyclass(name = "WindowFrame", module = "datafusion.expr", subclass)] +#[pyclass(frozen, name = "WindowFrame", module = "datafusion.expr", subclass)] #[derive(Clone)] pub struct PyWindowFrame { window_frame: WindowFrame, @@ -54,7 +54,12 @@ impl From for PyWindowFrame { } } -#[pyclass(name = "WindowFrameBound", module = "datafusion.expr", subclass)] +#[pyclass( + frozen, + name = "WindowFrameBound", + module = "datafusion.expr", + subclass +)] #[derive(Clone)] pub struct PyWindowFrameBound { frame_bound: WindowFrameBound, diff --git a/src/functions.rs b/src/functions.rs index e92cf053f..5956b67cf 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -230,17 +230,13 @@ fn col(name: &str) -> PyResult { /// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression. #[pyfunction] fn case(expr: PyExpr) -> PyResult { - Ok(PyCaseBuilder { - case_builder: datafusion::logical_expr::case(expr.expr), - }) + Ok(PyCaseBuilder::new(Some(expr))) } /// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression. #[pyfunction] fn when(when: PyExpr, then: PyExpr) -> PyResult { - Ok(PyCaseBuilder { - case_builder: datafusion::logical_expr::when(when.expr, then.expr), - }) + Ok(PyCaseBuilder::new(None).when(when, then)) } /// Helper function to find the appropriate window function. diff --git a/src/physical_plan.rs b/src/physical_plan.rs index 49db643e1..4994b0114 100644 --- a/src/physical_plan.rs +++ b/src/physical_plan.rs @@ -24,7 +24,7 @@ use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyBytes}; use crate::{context::PySessionContext, errors::PyDataFusionResult}; -#[pyclass(name = "ExecutionPlan", module = "datafusion", subclass)] +#[pyclass(frozen, name = "ExecutionPlan", module = "datafusion", subclass)] #[derive(Debug, Clone)] pub struct PyExecutionPlan { pub plan: Arc, diff --git a/src/record_batch.rs b/src/record_batch.rs index a85f05423..c3658cf4b 100644 --- a/src/record_batch.rs +++ b/src/record_batch.rs @@ -28,7 +28,7 @@ use pyo3::prelude::*; use pyo3::{pyclass, pymethods, PyObject, PyResult, Python}; use tokio::sync::Mutex; -#[pyclass(name = "RecordBatch", module = "datafusion", subclass)] +#[pyclass(name = "RecordBatch", module = "datafusion", subclass, frozen)] pub struct PyRecordBatch { batch: RecordBatch, } @@ -46,7 +46,7 @@ impl From for PyRecordBatch { } } -#[pyclass(name = "RecordBatchStream", module = "datafusion", subclass)] +#[pyclass(name = "RecordBatchStream", module = "datafusion", subclass, frozen)] pub struct PyRecordBatchStream { stream: Arc>, } @@ -61,12 +61,12 @@ impl PyRecordBatchStream { #[pymethods] impl PyRecordBatchStream { - fn next(&mut self, py: Python) -> PyResult { + fn next(&self, py: Python) -> PyResult { let stream = self.stream.clone(); wait_for_future(py, next_stream(stream, true))? } - fn __next__(&mut self, py: Python) -> PyResult { + fn __next__(&self, py: Python) -> PyResult { self.next(py) } diff --git a/src/sql/logical.rs b/src/sql/logical.rs index 97d320470..47ea39fdc 100644 --- a/src/sql/logical.rs +++ b/src/sql/logical.rs @@ -63,7 +63,7 @@ use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyBytes}; use crate::expr::logical_node::LogicalNode; -#[pyclass(name = "LogicalPlan", module = "datafusion", subclass)] +#[pyclass(frozen, name = "LogicalPlan", module = "datafusion", subclass)] #[derive(Debug, Clone)] pub struct PyLogicalPlan { pub(crate) plan: Arc, diff --git a/src/store.rs b/src/store.rs index 1e5fab472..998681854 100644 --- a/src/store.rs +++ b/src/store.rs @@ -36,7 +36,12 @@ pub enum StorageContexts { HTTP(PyHttpContext), } -#[pyclass(name = "LocalFileSystem", module = "datafusion.store", subclass)] +#[pyclass( + frozen, + name = "LocalFileSystem", + module = "datafusion.store", + subclass +)] #[derive(Debug, Clone)] pub struct PyLocalFileSystemContext { pub inner: Arc, @@ -62,7 +67,7 @@ impl PyLocalFileSystemContext { } } -#[pyclass(name = "MicrosoftAzure", module = "datafusion.store", subclass)] +#[pyclass(frozen, name = "MicrosoftAzure", module = "datafusion.store", subclass)] #[derive(Debug, Clone)] pub struct PyMicrosoftAzureContext { pub inner: Arc, @@ -134,7 +139,7 @@ impl PyMicrosoftAzureContext { } } -#[pyclass(name = "GoogleCloud", module = "datafusion.store", subclass)] +#[pyclass(frozen, name = "GoogleCloud", module = "datafusion.store", subclass)] #[derive(Debug, Clone)] pub struct PyGoogleCloudContext { pub inner: Arc, @@ -164,7 +169,7 @@ impl PyGoogleCloudContext { } } -#[pyclass(name = "AmazonS3", module = "datafusion.store", subclass)] +#[pyclass(frozen, name = "AmazonS3", module = "datafusion.store", subclass)] #[derive(Debug, Clone)] pub struct PyAmazonS3Context { pub inner: Arc, @@ -223,7 +228,7 @@ impl PyAmazonS3Context { } } -#[pyclass(name = "Http", module = "datafusion.store", subclass)] +#[pyclass(frozen, name = "Http", module = "datafusion.store", subclass)] #[derive(Debug, Clone)] pub struct PyHttpContext { pub url: String, diff --git a/src/substrait.rs b/src/substrait.rs index f1936b05e..291892cf8 100644 --- a/src/substrait.rs +++ b/src/substrait.rs @@ -27,7 +27,7 @@ use datafusion_substrait::serializer; use datafusion_substrait::substrait::proto::Plan; use prost::Message; -#[pyclass(name = "Plan", module = "datafusion.substrait", subclass)] +#[pyclass(frozen, name = "Plan", module = "datafusion.substrait", subclass)] #[derive(Debug, Clone)] pub struct PyPlan { pub plan: Plan, @@ -59,7 +59,7 @@ impl From for PyPlan { /// A PySubstraitSerializer is a representation of a Serializer that is capable of both serializing /// a `LogicalPlan` instance to Substrait Protobuf bytes and also deserialize Substrait Protobuf bytes /// to a valid `LogicalPlan` instance. -#[pyclass(name = "Serde", module = "datafusion.substrait", subclass)] +#[pyclass(frozen, name = "Serde", module = "datafusion.substrait", subclass)] #[derive(Debug, Clone)] pub struct PySubstraitSerializer; @@ -112,7 +112,7 @@ impl PySubstraitSerializer { } } -#[pyclass(name = "Producer", module = "datafusion.substrait", subclass)] +#[pyclass(frozen, name = "Producer", module = "datafusion.substrait", subclass)] #[derive(Debug, Clone)] pub struct PySubstraitProducer; @@ -129,7 +129,7 @@ impl PySubstraitProducer { } } -#[pyclass(name = "Consumer", module = "datafusion.substrait", subclass)] +#[pyclass(frozen, name = "Consumer", module = "datafusion.substrait", subclass)] #[derive(Debug, Clone)] pub struct PySubstraitConsumer; diff --git a/src/udaf.rs b/src/udaf.rs index 78f4e2b0c..eab4581df 100644 --- a/src/udaf.rs +++ b/src/udaf.rs @@ -155,7 +155,7 @@ pub fn to_rust_accumulator(accum: PyObject) -> AccumulatorFactoryFunction { } /// Represents an AggregateUDF -#[pyclass(name = "AggregateUDF", module = "datafusion", subclass)] +#[pyclass(frozen, name = "AggregateUDF", module = "datafusion", subclass)] #[derive(Debug, Clone)] pub struct PyAggregateUDF { pub(crate) function: AggregateUDF, diff --git a/src/udf.rs b/src/udf.rs index de1e3f18c..a9249d6c8 100644 --- a/src/udf.rs +++ b/src/udf.rs @@ -81,7 +81,7 @@ fn to_scalar_function_impl(func: PyObject) -> ScalarFunctionImplementation { } /// Represents a PyScalarUDF -#[pyclass(name = "ScalarUDF", module = "datafusion", subclass)] +#[pyclass(frozen, name = "ScalarUDF", module = "datafusion", subclass)] #[derive(Debug, Clone)] pub struct PyScalarUDF { pub(crate) function: ScalarUDF, diff --git a/src/udtf.rs b/src/udtf.rs index db16d6c05..55f306b17 100644 --- a/src/udtf.rs +++ b/src/udtf.rs @@ -31,7 +31,7 @@ use pyo3::exceptions::PyNotImplementedError; use pyo3::types::{PyCapsule, PyTuple}; /// Represents a user defined table function -#[pyclass(name = "TableFunction", module = "datafusion")] +#[pyclass(frozen, name = "TableFunction", module = "datafusion")] #[derive(Debug, Clone)] pub struct PyTableFunction { pub(crate) name: String, diff --git a/src/udwf.rs b/src/udwf.rs index 70a66e38f..ceeaa0ef1 100644 --- a/src/udwf.rs +++ b/src/udwf.rs @@ -210,7 +210,7 @@ pub fn to_rust_partition_evaluator(evaluator: PyObject) -> PartitionEvaluatorFac } /// Represents an WindowUDF -#[pyclass(name = "WindowUDF", module = "datafusion", subclass)] +#[pyclass(frozen, name = "WindowUDF", module = "datafusion", subclass)] #[derive(Debug, Clone)] pub struct PyWindowUDF { pub(crate) function: WindowUDF, diff --git a/src/unparser/dialect.rs b/src/unparser/dialect.rs index caeef9949..5df0a0c2e 100644 --- a/src/unparser/dialect.rs +++ b/src/unparser/dialect.rs @@ -22,7 +22,7 @@ use datafusion::sql::unparser::dialect::{ }; use pyo3::prelude::*; -#[pyclass(name = "Dialect", module = "datafusion.unparser", subclass)] +#[pyclass(frozen, name = "Dialect", module = "datafusion.unparser", subclass)] #[derive(Clone)] pub struct PyDialect { pub dialect: Arc, diff --git a/src/unparser/mod.rs b/src/unparser/mod.rs index b4b0fed10..f234345a7 100644 --- a/src/unparser/mod.rs +++ b/src/unparser/mod.rs @@ -25,7 +25,7 @@ use pyo3::{exceptions::PyValueError, prelude::*}; use crate::sql::logical::PyLogicalPlan; -#[pyclass(name = "Unparser", module = "datafusion.unparser", subclass)] +#[pyclass(frozen, name = "Unparser", module = "datafusion.unparser", subclass)] #[derive(Clone)] pub struct PyUnparser { dialect: Arc,