Skip to content

Commit 452fe1c

Browse files
committed
fix polars tests
1 parent c534560 commit 452fe1c

File tree

1 file changed

+38
-6
lines changed

1 file changed

+38
-6
lines changed

tests/fast/arrow/test_polars.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33

44
import pytest
5+
from packaging.version import parse as parse_version
56

67
import duckdb
78

@@ -11,6 +12,8 @@
1112

1213
from duckdb.polars_io import _pl_tree_to_sql, _predicate_to_expression # noqa: E402
1314

15+
pl_pre_1_36_0 = parse_version(pl.__version__) < parse_version("1.36.0")
16+
1417

1518
def valid_filter(filter):
1619
sql_expression = _predicate_to_expression(filter)
@@ -86,10 +89,18 @@ def test_polars_from_json(self, duckdb_cursor):
8689
res = duckdb_cursor.read_json(string).pl()
8790
assert str(res["entry"][0][0]) == "{'content': {'ManagedSystem': {'test': None}}}"
8891

89-
@pytest.mark.skipif(
90-
not hasattr(pl.exceptions, "PanicException"), reason="Polars has no PanicException in this version"
91-
)
92-
def test_polars_from_json_error(self, duckdb_cursor):
92+
@pytest.mark.skipif(pl_pre_1_36_0, reason="Polars < 1.36.0 doesn't support arrow extensions")
93+
def test_polars_from_json_post_pl_1_36_0(self, duckdb_cursor):
94+
from io import StringIO
95+
96+
duckdb_cursor.sql("set arrow_lossless_conversion=true")
97+
string = StringIO("""{"entry":[{"content":{"ManagedSystem":{"test":null}}}]}""")
98+
pl.register_extension_type("arrow.json", pl.Extension)
99+
res = duckdb_cursor.read_json(string).pl()
100+
assert str(res["entry"][0][0]) == "{'content': {'ManagedSystem': {'test': None}}}"
101+
102+
@pytest.mark.skipif(not pl_pre_1_36_0, reason="Polars >= 1.36.0 supports arrow extensions")
103+
def test_polars_from_json_pre_pl_1_36_0(self, duckdb_cursor):
93104
from io import StringIO
94105

95106
duckdb_cursor.sql("set arrow_lossless_conversion=true")
@@ -426,13 +437,34 @@ def test_polars_lazy_pushdown_timestamp(self, duckdb_cursor):
426437
lazy_df.filter((pl.col("a") == ts_2020) | (pl.col("b") == ts_2008)).select(pl.len()).collect().item() == 2
427438
)
428439

429-
# Validate Filter
440+
@pytest.mark.skipif(pl_pre_1_36_0, reason="Polars < 1.36.0 expressions on dates produce casts in predicates")
441+
def test_polars_predicate_to_expression_post_1_36_0(self):
442+
ts_2008 = datetime.datetime(2008, 1, 1, 0, 0, 1)
443+
ts_2010 = datetime.datetime(2010, 1, 1, 10, 0, 1)
444+
ts_2020 = datetime.datetime(2020, 3, 1, 10, 0, 1)
445+
# Validate filters - none of these produce casts in polars >= 1.36.0
446+
valid_filter(pl.col("a") == ts_2008)
447+
valid_filter(pl.col("a") > ts_2008)
448+
valid_filter(pl.col("a") >= ts_2010)
449+
valid_filter(pl.col("a") < ts_2010)
450+
valid_filter(pl.col("a") <= ts_2010)
451+
valid_filter(pl.col("a").is_null())
452+
valid_filter(pl.col("a").is_not_null())
453+
valid_filter((pl.col("a") == ts_2010) & (pl.col("b") == ts_2008))
454+
valid_filter((pl.col("a") == ts_2020) & (pl.col("b") == ts_2010) & (pl.col("c") == ts_2020))
455+
valid_filter((pl.col("a") == ts_2020) | (pl.col("b") == ts_2008))
456+
457+
@pytest.mark.skipif(not pl_pre_1_36_0, reason="Polars >= 1.36.0 expressions on dates don't produce casts")
458+
def test_polars_predicate_to_expression_pre_1_36_0(self):
459+
ts_2008 = datetime.datetime(2008, 1, 1, 0, 0, 1)
460+
ts_2010 = datetime.datetime(2010, 1, 1, 10, 0, 1)
461+
ts_2020 = datetime.datetime(2020, 3, 1, 10, 0, 1)
462+
# Validate filters
430463
invalid_filter(pl.col("a") == ts_2008)
431464
invalid_filter(pl.col("a") > ts_2008)
432465
invalid_filter(pl.col("a") >= ts_2010)
433466
invalid_filter(pl.col("a") < ts_2010)
434467
invalid_filter(pl.col("a") <= ts_2010)
435-
# These two are actually valid because they don't produce a cast
436468
valid_filter(pl.col("a").is_null())
437469
valid_filter(pl.col("a").is_not_null())
438470
invalid_filter((pl.col("a") == ts_2010) & (pl.col("b") == ts_2008))

0 commit comments

Comments
 (0)