Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions duckdb/polars_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,18 @@ def _pl_operation_to_sql(op: str) -> str:
raise NotImplementedError(op)


def _escape_sql_identifier(identifier: str) -> str:
"""
Escape SQL identifiers by doubling any double quotes and wrapping in double quotes.

Example:
>>> _escape_sql_identifier('column"name')
'"column""name"'
"""
escaped = identifier.replace('"', '""')
return f'"{escaped}"'


def _pl_tree_to_sql(tree: dict) -> str:
"""
Recursively convert a Polars expression tree (as JSON) to a SQL string.
Expand Down Expand Up @@ -95,7 +107,8 @@ def _pl_tree_to_sql(tree: dict) -> str:
)
if node_type == "Column":
# A reference to a column name
return subtree
# Wrap in quotes to handle special characters
return _escape_sql_identifier(subtree)

if node_type in ("Literal", "Dyn"):
# Recursively process dynamic or literal values
Expand Down Expand Up @@ -196,7 +209,7 @@ def source_generator(
duck_predicate = None
relation_final = relation
if with_columns is not None:
cols = ",".join(with_columns)
cols = ",".join(map(_escape_sql_identifier, with_columns))
relation_final = relation_final.project(cols)
if n_rows is not None:
relation_final = relation_final.limit(n_rows)
Expand All @@ -213,7 +226,6 @@ def source_generator(
while True:
try:
record_batch = results.read_next_batch()
df = pl.from_arrow(record_batch)
if predicate is not None and duck_predicate is None:
# We have a predicate, but did not manage to push it down, we fallback here
yield pl.from_arrow(record_batch).filter(predicate)
Expand Down
30 changes: 30 additions & 0 deletions tests/fast/arrow/test_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,36 @@ def test_polars_lazy(self, duckdb_cursor):
]
assert lazy_df.filter(pl.col("b") < 32).select('a').collect().to_dicts() == [{'a': 'Mark'}, {'a': 'Thijs'}]

def test_polars_column_with_tricky_name(self, duckdb_cursor):
# Test that a polars DataFrame with a column name that is non standard still works
df_colon = pl.DataFrame({"x:y": [1, 2]})
lf = duckdb_cursor.sql("from df_colon").pl(lazy=True)
result = lf.select(pl.all()).collect()
assert result.to_dicts() == [{"x:y": 1}, {"x:y": 2}]
result = lf.select(pl.all()).filter(pl.col("x:y") == 1).collect()
assert result.to_dicts() == [{"x:y": 1}]

df_space = pl.DataFrame({"x y": [1, 2]})
lf = duckdb_cursor.sql("from df_space").pl(lazy=True)
result = lf.select(pl.all()).collect()
assert result.to_dicts() == [{"x y": 1}, {"x y": 2}]
result = lf.select(pl.all()).filter(pl.col("x y") == 1).collect()
assert result.to_dicts() == [{"x y": 1}]

df_dot = pl.DataFrame({"x.y": [1, 2]})
lf = duckdb_cursor.sql("from df_dot").pl(lazy=True)
result = lf.select(pl.all()).collect()
assert result.to_dicts() == [{"x.y": 1}, {"x.y": 2}]
result = lf.select(pl.all()).filter(pl.col("x.y") == 1).collect()
assert result.to_dicts() == [{"x.y": 1}]

df_quote = pl.DataFrame({'"xy"': [1, 2]})
lf = duckdb_cursor.sql("from df_quote").pl(lazy=True)
result = lf.select(pl.all()).collect()
assert result.to_dicts() == [{'"xy"': 1}, {'"xy"': 2}]
result = lf.select(pl.all()).filter(pl.col('"xy"') == 1).collect()
assert result.to_dicts() == [{'"xy"': 1}]

@pytest.mark.parametrize(
'data_type',
[
Expand Down