Skip to content

Commit c1ed4fa

Browse files
Fix Polars expr pushdown (duckdb#102)
See duckdb#98 We have our own Polars IO plugin to create lazy polars dataframes. We try to push as many predicates down into DuckDB as we can, for which we try to map Polars expressions (including datatypes) to SQL expressions. To do this we depend on Polars' `Expr.meta.serialize`. None of this is very stable. Polars IO source plugins are marked `@unstable` and `Expr.meta.serialize` says "Serialization is not stable across Polars versions". In this case the problem seems to come from Polars requiring an explicit scale to be set for decimals ([this pr](pola-rs/polars#24542)). The serialized format seems to have changed into: ```json { "expr": { "Literal": { "Scalar": { "Decimal": [ 1, 38, // This now includes the scale 0 ] } } }, "dtype": { "Literal": { "Decimal": [ 38, // this was already there 0 ] } }, "options": "Strict" } ``` Interestingly, even if we explicitly set precision to e.g. 20, the relevant part of the serialized expression looks as follows: ```json { "expr": { "Literal": { "Scalar": { "Decimal": [ 1, 38, // Still 38? 0 ] } } }, "dtype": { "Literal": { "Decimal": [ 20, // This is correct 0 ] } }, "options": "Strict" } ``` This PR allows decimals to be represented as either a two- or a three-item list.
2 parents 511d5e2 + dca3ccb commit c1ed4fa

File tree

2 files changed

+58
-5
lines changed

2 files changed

+58
-5
lines changed

duckdb/polars_io.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations # noqa: D100
22

3+
import contextlib
34
import datetime
45
import json
56
import typing
@@ -176,9 +177,12 @@ def _pl_tree_to_sql(tree: _ExpressionTree) -> str:
176177
if dtype.startswith("{'Decimal'") or dtype == "Decimal":
177178
decimal_value = value["Decimal"]
178179
assert isinstance(decimal_value, list), (
179-
f"A {dtype} should be a two member list but got {type(decimal_value)}"
180+
f"A {dtype} should be a two or three member list but got {type(decimal_value)}"
180181
)
181-
return str(Decimal(decimal_value[0]) / Decimal(10 ** decimal_value[1]))
182+
assert 2 <= len(decimal_value) <= 3, (
183+
f"A {dtype} should be a two or three member list but got {len(decimal_value)} member list"
184+
)
185+
return str(Decimal(decimal_value[0]) / Decimal(10 ** decimal_value[-1]))
182186

183187
# Datetime with microseconds since epoch
184188
if dtype.startswith("{'Datetime'") or dtype == "Datetime":
@@ -260,7 +264,8 @@ def source_generator(
260264
relation_final = relation_final.limit(n_rows)
261265
if predicate is not None:
262266
# We have a predicate, if possible, we push it down to DuckDB
263-
duck_predicate = _predicate_to_expression(predicate)
267+
with contextlib.suppress(AssertionError, KeyError):
268+
duck_predicate = _predicate_to_expression(predicate)
264269
# Try to pushdown filter, if one exists
265270
if duck_predicate is not None:
266271
relation_final = relation_final.filter(duck_predicate)

tests/fast/arrow/test_polars.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import datetime
2+
import json
23

34
import pytest
45

@@ -8,7 +9,7 @@
89
arrow = pytest.importorskip("pyarrow")
910
pl_testing = pytest.importorskip("polars.testing")
1011

11-
from duckdb.polars_io import _predicate_to_expression # noqa: E402
12+
from duckdb.polars_io import _pl_tree_to_sql, _predicate_to_expression # noqa: E402
1213

1314

1415
def valid_filter(filter):
@@ -175,7 +176,7 @@ def test_polars_column_with_tricky_name(self, duckdb_cursor):
175176
"UBIGINT",
176177
"FLOAT",
177178
"DOUBLE",
178-
# "HUGEINT",
179+
"HUGEINT",
179180
"DECIMAL(4,1)",
180181
"DECIMAL(9,1)",
181182
"DECIMAL(18,4)",
@@ -605,3 +606,50 @@ def test_polars_lazy_many_batches(self, duckdb_cursor):
605606
correct = duckdb_cursor.execute("FROM t").fetchall()
606607

607608
assert res == correct
609+
610+
def test_invalid_expr_json(self):
611+
bad_key_expr = """
612+
{
613+
"BinaryExpr": {
614+
"left": { "Column": "foo" },
615+
"middle": "Gt",
616+
"right": { "Literal": { "Int": 5 } }
617+
}
618+
}
619+
"""
620+
with pytest.raises(KeyError, match="'op'"):
621+
_pl_tree_to_sql(json.loads(bad_key_expr))
622+
623+
bad_type_expr = """
624+
{
625+
"BinaryExpr": {
626+
"left": { "Column": [ "foo" ] },
627+
"op": "Gt",
628+
"right": { "Literal": { "Int": 5 } }
629+
}
630+
}
631+
"""
632+
with pytest.raises(AssertionError, match="The col name of a Column should be a str but got"):
633+
_pl_tree_to_sql(json.loads(bad_type_expr))
634+
635+
def test_decimal_scale(self):
636+
scalar_decimal_no_scale = """
637+
{ "Scalar": {
638+
"Decimal": [
639+
1,
640+
0
641+
]
642+
} }
643+
"""
644+
assert _pl_tree_to_sql(json.loads(scalar_decimal_no_scale)) == "1"
645+
646+
scalar_decimal_scale = """
647+
{ "Scalar": {
648+
"Decimal": [
649+
1,
650+
38,
651+
0
652+
]
653+
} }
654+
"""
655+
assert _pl_tree_to_sql(json.loads(scalar_decimal_scale)) == "1"

0 commit comments

Comments
 (0)