|
2 | 2 | import json |
3 | 3 |
|
4 | 4 | import pytest |
| 5 | +from packaging.version import parse as parse_version |
5 | 6 |
|
6 | 7 | import duckdb |
7 | 8 |
|
|
11 | 12 |
|
12 | 13 | from duckdb.polars_io import _pl_tree_to_sql, _predicate_to_expression # noqa: E402 |
13 | 14 |
|
| 15 | +pl_pre_1_36_0 = parse_version(pl.__version__) < parse_version("1.36.0") |
| 16 | + |
14 | 17 |
|
15 | 18 | def valid_filter(filter): |
16 | 19 | sql_expression = _predicate_to_expression(filter) |
@@ -86,10 +89,18 @@ def test_polars_from_json(self, duckdb_cursor): |
86 | 89 | res = duckdb_cursor.read_json(string).pl() |
87 | 90 | assert str(res["entry"][0][0]) == "{'content': {'ManagedSystem': {'test': None}}}" |
88 | 91 |
|
89 | | - @pytest.mark.skipif( |
90 | | - not hasattr(pl.exceptions, "PanicException"), reason="Polars has no PanicException in this version" |
91 | | - ) |
92 | | - def test_polars_from_json_error(self, duckdb_cursor): |
| 92 | + @pytest.mark.skipif(pl_pre_1_36_0, reason="Polars < 1.36.0 doesn't support arrow extensions") |
| 93 | + def test_polars_from_json_post_pl_1_36_0(self, duckdb_cursor): |
| 94 | + from io import StringIO |
| 95 | + |
| 96 | + duckdb_cursor.sql("set arrow_lossless_conversion=true") |
| 97 | + string = StringIO("""{"entry":[{"content":{"ManagedSystem":{"test":null}}}]}""") |
| 98 | + pl.register_extension_type("arrow.json", pl.Extension) |
| 99 | + res = duckdb_cursor.read_json(string).pl() |
| 100 | + assert str(res["entry"][0][0]) == "{'content': {'ManagedSystem': {'test': None}}}" |
| 101 | + |
| 102 | + @pytest.mark.skipif(not pl_pre_1_36_0, reason="Polars >= 1.36.0 supports arrow extensions") |
| 103 | + def test_polars_from_json_pre_pl_1_36_0(self, duckdb_cursor): |
93 | 104 | from io import StringIO |
94 | 105 |
|
95 | 106 | duckdb_cursor.sql("set arrow_lossless_conversion=true") |
@@ -426,13 +437,34 @@ def test_polars_lazy_pushdown_timestamp(self, duckdb_cursor): |
426 | 437 | lazy_df.filter((pl.col("a") == ts_2020) | (pl.col("b") == ts_2008)).select(pl.len()).collect().item() == 2 |
427 | 438 | ) |
428 | 439 |
|
429 | | - # Validate Filter |
| 440 | + @pytest.mark.skipif(pl_pre_1_36_0, reason="Polars < 1.36.0 expressions on dates produce casts in predicates") |
| 441 | + def test_polars_predicate_to_expression_post_1_36_0(self): |
| 442 | + ts_2008 = datetime.datetime(2008, 1, 1, 0, 0, 1) |
| 443 | + ts_2010 = datetime.datetime(2010, 1, 1, 10, 0, 1) |
| 444 | + ts_2020 = datetime.datetime(2020, 3, 1, 10, 0, 1) |
| 445 | + # Validate filters - none of these produce casts in polars >= 1.36.0 |
| 446 | + valid_filter(pl.col("a") == ts_2008) |
| 447 | + valid_filter(pl.col("a") > ts_2008) |
| 448 | + valid_filter(pl.col("a") >= ts_2010) |
| 449 | + valid_filter(pl.col("a") < ts_2010) |
| 450 | + valid_filter(pl.col("a") <= ts_2010) |
| 451 | + valid_filter(pl.col("a").is_null()) |
| 452 | + valid_filter(pl.col("a").is_not_null()) |
| 453 | + valid_filter((pl.col("a") == ts_2010) & (pl.col("b") == ts_2008)) |
| 454 | + valid_filter((pl.col("a") == ts_2020) & (pl.col("b") == ts_2010) & (pl.col("c") == ts_2020)) |
| 455 | + valid_filter((pl.col("a") == ts_2020) | (pl.col("b") == ts_2008)) |
| 456 | + |
| 457 | + @pytest.mark.skipif(not pl_pre_1_36_0, reason="Polars >= 1.36.0 expressions on dates don't produce casts") |
| 458 | + def test_polars_predicate_to_expression_pre_1_36_0(self): |
| 459 | + ts_2008 = datetime.datetime(2008, 1, 1, 0, 0, 1) |
| 460 | + ts_2010 = datetime.datetime(2010, 1, 1, 10, 0, 1) |
| 461 | + ts_2020 = datetime.datetime(2020, 3, 1, 10, 0, 1) |
| 462 | + # Validate filters |
430 | 463 | invalid_filter(pl.col("a") == ts_2008) |
431 | 464 | invalid_filter(pl.col("a") > ts_2008) |
432 | 465 | invalid_filter(pl.col("a") >= ts_2010) |
433 | 466 | invalid_filter(pl.col("a") < ts_2010) |
434 | 467 | invalid_filter(pl.col("a") <= ts_2010) |
435 | | - # These two are actually valid because they don't produce a cast |
436 | 468 | valid_filter(pl.col("a").is_null()) |
437 | 469 | valid_filter(pl.col("a").is_not_null()) |
438 | 470 | invalid_filter((pl.col("a") == ts_2010) & (pl.col("b") == ts_2008)) |
|
0 commit comments