diff --git a/duckdb/polars_io.py b/duckdb/polars_io.py index f43e0afd..61ead5bc 100644 --- a/duckdb/polars_io.py +++ b/duckdb/polars_io.py @@ -1,5 +1,6 @@ from __future__ import annotations # noqa: D100 +import contextlib import datetime import json import typing @@ -176,9 +177,12 @@ def _pl_tree_to_sql(tree: _ExpressionTree) -> str: if dtype.startswith("{'Decimal'") or dtype == "Decimal": decimal_value = value["Decimal"] assert isinstance(decimal_value, list), ( - f"A {dtype} should be a two member list but got {type(decimal_value)}" + f"A {dtype} should be a two or three member list but got {type(decimal_value)}" ) - return str(Decimal(decimal_value[0]) / Decimal(10 ** decimal_value[1])) + assert 2 <= len(decimal_value) <= 3, ( + f"A {dtype} should be a two or three member list but got {len(decimal_value)} member list" + ) + return str(Decimal(decimal_value[0]) / Decimal(10 ** decimal_value[-1])) # Datetime with microseconds since epoch if dtype.startswith("{'Datetime'") or dtype == "Datetime": @@ -260,7 +264,8 @@ def source_generator( relation_final = relation_final.limit(n_rows) if predicate is not None: # We have a predicate, if possible, we push it down to DuckDB - duck_predicate = _predicate_to_expression(predicate) + with contextlib.suppress(AssertionError, KeyError): + duck_predicate = _predicate_to_expression(predicate) # Try to pushdown filter, if one exists if duck_predicate is not None: relation_final = relation_final.filter(duck_predicate) diff --git a/tests/fast/arrow/test_polars.py b/tests/fast/arrow/test_polars.py index 705532c8..d5621701 100644 --- a/tests/fast/arrow/test_polars.py +++ b/tests/fast/arrow/test_polars.py @@ -1,4 +1,5 @@ import datetime +import json import pytest @@ -8,7 +9,7 @@ arrow = pytest.importorskip("pyarrow") pl_testing = pytest.importorskip("polars.testing") -from duckdb.polars_io import _predicate_to_expression # noqa: E402 +from duckdb.polars_io import _pl_tree_to_sql, _predicate_to_expression # noqa: E402 def valid_filter(filter): @@ -175,7 +176,7 @@ def test_polars_column_with_tricky_name(self, duckdb_cursor): "UBIGINT", "FLOAT", "DOUBLE", - # "HUGEINT", + "HUGEINT", "DECIMAL(4,1)", "DECIMAL(9,1)", "DECIMAL(18,4)", @@ -605,3 +606,50 @@ def test_polars_lazy_many_batches(self, duckdb_cursor): correct = duckdb_cursor.execute("FROM t").fetchall() assert res == correct + + def test_invalid_expr_json(self): + bad_key_expr = """ + { + "BinaryExpr": { + "left": { "Column": "foo" }, + "middle": "Gt", + "right": { "Literal": { "Int": 5 } } + } + } + """ + with pytest.raises(KeyError, match="'op'"): + _pl_tree_to_sql(json.loads(bad_key_expr)) + + bad_type_expr = """ + { + "BinaryExpr": { + "left": { "Column": [ "foo" ] }, + "op": "Gt", + "right": { "Literal": { "Int": 5 } } + } + } + """ + with pytest.raises(AssertionError, match="The col name of a Column should be a str but got"): + _pl_tree_to_sql(json.loads(bad_type_expr)) + + def test_decimal_scale(self): + scalar_decimal_no_scale = """ + { "Scalar": { + "Decimal": [ + 1, + 0 + ] + } } + """ + assert _pl_tree_to_sql(json.loads(scalar_decimal_no_scale)) == "1" + + scalar_decimal_scale = """ + { "Scalar": { + "Decimal": [ + 1, + 38, + 0 + ] + } } + """ + assert _pl_tree_to_sql(json.loads(scalar_decimal_scale)) == "1"