diff --git a/duckdb/__init__.pyi b/duckdb/__init__.pyi index adf142dd..8f27e5e3 100644 --- a/duckdb/__init__.pyi +++ b/duckdb/__init__.pyi @@ -415,7 +415,7 @@ class DuckDBPyRelation: def variance(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... def list(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ... - def arrow(self, batch_size: int = ...) -> pyarrow.lib.Table: ... + def arrow(self, batch_size: int = ...) -> pyarrow.lib.RecordBatchReader: ... def __arrow_c_stream__(self, requested_schema: Optional[object] = None) -> object: ... def create(self, table_name: str) -> None: ... def create_view(self, view_name: str, replace: bool = ...) -> DuckDBPyRelation: ... @@ -448,6 +448,7 @@ class DuckDBPyRelation: def pl(self, rows_per_batch: int = ..., connection: DuckDBPyConnection = ...) -> polars.DataFrame: ... def query(self, virtual_table_name: str, sql_query: str) -> DuckDBPyRelation: ... def record_batch(self, batch_size: int = ...) -> pyarrow.lib.RecordBatchReader: ... + def fetch_record_batch(self, rows_per_batch: int = ..., *, connection: DuckDBPyConnection = ...) -> pyarrow.lib.RecordBatchReader: ... def select_types(self, types: List[Union[str, DuckDBPyType]]) -> DuckDBPyRelation: ... def select_dtypes(self, types: List[Union[str, DuckDBPyType]]) -> DuckDBPyRelation: ... def set_alias(self, alias: str) -> DuckDBPyRelation: ... 
diff --git a/duckdb/experimental/spark/sql/dataframe.py b/duckdb/experimental/spark/sql/dataframe.py index b8a4698b..a81a423b 100644 --- a/duckdb/experimental/spark/sql/dataframe.py +++ b/duckdb/experimental/spark/sql/dataframe.py @@ -75,7 +75,7 @@ def toArrow(self) -> "pa.Table": age: [[2,5]] name: [["Alice","Bob"]] """ - return self.relation.arrow() + return self.relation.to_arrow_table() def createOrReplaceTempView(self, name: str) -> None: """Creates or replaces a local temporary view with this :class:`DataFrame`. diff --git a/src/duckdb_py/pyrelation/initialize.cpp b/src/duckdb_py/pyrelation/initialize.cpp index a93a54b5..7992cc17 100644 --- a/src/duckdb_py/pyrelation/initialize.cpp +++ b/src/duckdb_py/pyrelation/initialize.cpp @@ -61,7 +61,7 @@ static void InitializeConsumers(py::class_ &m) { py::arg("date_as_object") = false) .def("fetch_df_chunk", &DuckDBPyRelation::FetchDFChunk, "Execute and fetch a chunk of the rows", py::arg("vectors_per_chunk") = 1, py::kw_only(), py::arg("date_as_object") = false) - .def("arrow", &DuckDBPyRelation::ToArrowTable, "Execute and fetch all rows as an Arrow Table", + .def("arrow", &DuckDBPyRelation::ToRecordBatch, "Execute and return an Arrow Record Batch Reader that yields all rows", py::arg("batch_size") = 1000000) .def("fetch_arrow_table", &DuckDBPyRelation::ToArrowTable, "Execute and fetch all rows as an Arrow Table", py::arg("batch_size") = 1000000) @@ -78,10 +78,18 @@ static void InitializeConsumers(py::class_ &m) { )"; m.def("__arrow_c_stream__", &DuckDBPyRelation::ToArrowCapsule, capsule_docs, py::arg("requested_schema") = py::none()); - m.def("record_batch", &DuckDBPyRelation::ToRecordBatch, - "Execute and return an Arrow Record Batch Reader that yields all rows", py::arg("batch_size") = 1000000) - .def("fetch_arrow_reader", &DuckDBPyRelation::ToRecordBatch, - "Execute and return an Arrow Record Batch Reader that yields all rows", py::arg("batch_size") = 1000000); + m.def("fetch_record_batch", 
&DuckDBPyRelation::ToRecordBatch, + "Execute and return an Arrow Record Batch Reader that yields all rows", py::arg("rows_per_batch") = 1000000) + .def("fetch_arrow_reader", &DuckDBPyRelation::ToRecordBatch, + "Execute and return an Arrow Record Batch Reader that yields all rows", py::arg("batch_size") = 1000000) + .def("record_batch", + [](pybind11::object &self, idx_t rows_per_batch) + { + if (PyErr_WarnEx(PyExc_DeprecationWarning, + "record_batch() is deprecated, use fetch_record_batch() instead.", + 0) == -1) { throw py::error_already_set(); } + return self.attr("fetch_record_batch")(rows_per_batch); + }, py::arg("batch_size") = 1000000); } static void InitializeAggregates(py::class_ &m) { diff --git a/tests/pytest.ini b/tests/pytest.ini index 5dd3c306..0c17afd5 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -3,6 +3,7 @@ filterwarnings = error ignore::UserWarning + ignore::DeprecationWarning # Jupyter is throwing DeprecationWarnings ignore:function ham\(\) is deprecated:DeprecationWarning # Pyspark is throwing these warnings