Skip to content

Commit fb7670b

Browse files
authored
fix: Handle duckdb breaking change in arrow (#3119)
1 parent 5d5aa7c commit fb7670b

File tree

8 files changed

+30
-9
lines changed

8 files changed

+30
-9
lines changed

docs/generating_sql.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ There are several ways to find out.
2424

2525
## Via DuckDB
2626

27-
You can also generate SQL directly from DuckDB.
27+
You can generate SQL directly from DuckDB.
2828

2929
```python exec="1" source="above" session="generating-sql" result="sql"
3030
import duckdb

narwhals/_duckdb/dataframe.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
not_implemented,
2828
parse_columns_to_drop,
2929
requires,
30+
to_pyarrow_table,
3031
zip_strict,
3132
)
3233
from narwhals.dependencies import get_duckdb
@@ -138,7 +139,7 @@ def collect(
138139
from narwhals._arrow.dataframe import ArrowDataFrame
139140

140141
return ArrowDataFrame(
141-
self.native.arrow(),
142+
to_pyarrow_table(self.native.arrow()),
142143
validate_backend_version=True,
143144
version=self._version,
144145
validate_column_names=True,
@@ -255,7 +256,7 @@ def to_pandas(self) -> pd.DataFrame:
255256

256257
def to_arrow(self) -> pa.Table:
257258
# only if version is v1, keep around for backcompat
258-
return self.native.arrow()
259+
return self.lazy().collect(Implementation.PYARROW).native # type: ignore[no-any-return]
259260

260261
def _with_version(self, version: Version) -> Self:
261262
return self.__class__(self.native, version=version)

narwhals/_ibis/dataframe.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
Version,
1616
not_implemented,
1717
parse_columns_to_drop,
18+
to_pyarrow_table,
1819
zip_strict,
1920
)
2021
from narwhals.exceptions import ColumnNotFoundError, InvalidOperationError
@@ -109,7 +110,7 @@ def collect(
109110
from narwhals._arrow.dataframe import ArrowDataFrame
110111

111112
return ArrowDataFrame(
112-
self.native.to_pyarrow(),
113+
to_pyarrow_table(self.native.to_pyarrow()),
113114
validate_backend_version=True,
114115
version=self._version,
115116
validate_column_names=True,

narwhals/_spark_like/dataframe.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
generate_temporary_column_name,
2323
not_implemented,
2424
parse_columns_to_drop,
25+
to_pyarrow_table,
2526
zip_strict,
2627
)
2728
from narwhals.exceptions import InvalidOperationError
@@ -184,7 +185,7 @@ def _collect_to_arrow(self) -> pa.Table:
184185
pa_schema = self._to_arrow_schema()
185186
return pa.Table.from_pandas(self.native.toPandas(), schema=pa_schema)
186187
else:
187-
return self.native.toArrow()
188+
return to_pyarrow_table(self.native.toArrow())
188189

189190
def _iter_columns(self) -> Iterator[Column]:
190191
for col in self.columns:

narwhals/_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2148,3 +2148,11 @@ def __get__(
21482148
def __get__(self, instance: LazyFrame[Any], owner: Any) -> _LazyAllowedImpl: ...
21492149
def __get__(self, instance: Narwhals[Any] | None, owner: Any) -> Any:
21502150
return self if instance is None else instance._compliant._implementation
2151+
2152+
2153+
def to_pyarrow_table(tbl: pa.Table | pa.RecordBatchReader) -> pa.Table:
    """Return `tbl` as a materialized `pyarrow.Table`.

    Some backends hand back a streaming `pyarrow.RecordBatchReader` instead of
    a `pyarrow.Table` from their "to arrow" call. This helper normalizes both
    cases to a `pyarrow.Table`.

    Arguments:
        tbl: Either an already-materialized table or a record-batch reader.

    Returns:
        `tbl` itself when it is already a table, otherwise a new table holding
        every batch read from the reader.
    """
    import pyarrow as pa  # ignore-banned-import

    if isinstance(tbl, pa.RecordBatchReader):  # pragma: no cover
        # `read_all` builds the table from the reader's own schema, so a reader
        # that yields zero batches still produces an empty table with the right
        # columns. `pa.Table.from_batches(tbl)` without an explicit schema
        # raises in that case because there is no batch to infer it from.
        return tbl.read_all()
    return tbl

tests/expr_and_series/reduction_test.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,15 @@ def test_empty_scalar_reduction_select(constructor: Constructor) -> None:
9191
assert_equal_data(result, expected)
9292

9393

94-
def test_empty_scalar_reduction_with_columns(constructor: Constructor) -> None:
94+
def test_empty_scalar_reduction_with_columns(
95+
constructor: Constructor, request: pytest.FixtureRequest
96+
) -> None:
9597
if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3):
9698
pytest.skip()
99+
if any(
100+
x in str(constructor) for x in ("duckdb", "sqlframe", "ibis")
101+
) and DUCKDB_VERSION >= (1, 4):
102+
request.applymarker(pytest.mark.xfail)
97103
from itertools import chain
98104

99105
data = {

tests/frame/collect_test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import narwhals as nw
88
from narwhals._utils import Implementation
99
from narwhals.dependencies import get_cudf, get_modin, get_polars
10-
from tests.utils import POLARS_VERSION, Constructor, assert_equal_data
10+
from tests.utils import DUCKDB_VERSION, POLARS_VERSION, Constructor, assert_equal_data
1111

1212
if TYPE_CHECKING:
1313
from narwhals._typing import Arrow, Dask, IntoBackend, Modin, Pandas, Polars
@@ -163,7 +163,11 @@ def test_collect_with_kwargs(constructor: Constructor) -> None:
163163
assert_equal_data(result, expected)
164164

165165

166-
def test_collect_empty(constructor: Constructor) -> None:
166+
def test_collect_empty(constructor: Constructor, request: pytest.FixtureRequest) -> None:
167+
if any(
168+
x in str(constructor) for x in ("duckdb", "sqlframe", "ibis")
169+
) and DUCKDB_VERSION >= (1, 4):
170+
request.applymarker(pytest.mark.xfail)
167171
df = nw.from_native(constructor({"a": [1, 2, 3]}))
168172
lf = df.filter(nw.col("a").is_null()).with_columns(b=nw.lit(None)).lazy()
169173
result = lf.collect()

tests/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def assert_equal_data(result: Any, expected: Mapping[str, Any]) -> None:
8888
and result._compliant_frame._implementation.is_spark_like()
8989
)
9090
if is_duckdb:
91-
result = from_native(result.to_native().arrow())
91+
result = from_native(result.collect("pyarrow"))
9292
if is_ibis:
9393
result = from_native(result.to_native().to_pyarrow())
9494
if hasattr(result, "collect"):

0 commit comments

Comments
 (0)