diff --git a/duckdb/polars_io.py b/duckdb/polars_io.py
index dbe8727b..d8d4cfe9 100644
--- a/duckdb/polars_io.py
+++ b/duckdb/polars_io.py
@@ -58,6 +58,18 @@ def _pl_operation_to_sql(op: str) -> str:
         raise NotImplementedError(op)
 
 
+def _escape_sql_identifier(identifier: str) -> str:
+    """
+    Escape SQL identifiers by doubling any double quotes and wrapping in double quotes.
+
+    Example:
+    >>> _escape_sql_identifier('column"name')
+    '"column""name"'
+    """
+    escaped = identifier.replace('"', '""')
+    return f'"{escaped}"'
+
+
 def _pl_tree_to_sql(tree: dict) -> str:
     """
     Recursively convert a Polars expression tree (as JSON) to a SQL string.
@@ -95,7 +107,8 @@ def _pl_tree_to_sql(tree: dict) -> str:
         )
     if node_type == "Column":
         # A reference to a column name
-        return subtree
+        # Wrap in quotes to handle special characters
+        return _escape_sql_identifier(subtree)
 
     if node_type in ("Literal", "Dyn"):
         # Recursively process dynamic or literal values
@@ -196,7 +209,7 @@ def source_generator(
     duck_predicate = None
     relation_final = relation
     if with_columns is not None:
-        cols = ",".join(with_columns)
+        cols = ",".join(map(_escape_sql_identifier, with_columns))
         relation_final = relation_final.project(cols)
     if n_rows is not None:
         relation_final = relation_final.limit(n_rows)
@@ -213,7 +226,6 @@ def source_generator(
     while True:
         try:
             record_batch = results.read_next_batch()
-            df = pl.from_arrow(record_batch)
             if predicate is not None and duck_predicate is None:
                 # We have a predicate, but did not manage to push it down, we fallback here
                 yield pl.from_arrow(record_batch).filter(predicate)
diff --git a/tests/fast/arrow/test_polars.py b/tests/fast/arrow/test_polars.py
index 89ccf031..87e2f726 100644
--- a/tests/fast/arrow/test_polars.py
+++ b/tests/fast/arrow/test_polars.py
@@ -131,6 +131,36 @@ def test_polars_lazy(self, duckdb_cursor):
         ]
         assert lazy_df.filter(pl.col("b") < 32).select('a').collect().to_dicts() == [{'a': 'Mark'}, {'a': 'Thijs'}]
 
+    def test_polars_column_with_tricky_name(self, duckdb_cursor):
+        # Test that a polars DataFrame with a non-standard column name still works
+        df_colon = pl.DataFrame({"x:y": [1, 2]})
+        lf = duckdb_cursor.sql("from df_colon").pl(lazy=True)
+        result = lf.select(pl.all()).collect()
+        assert result.to_dicts() == [{"x:y": 1}, {"x:y": 2}]
+        result = lf.select(pl.all()).filter(pl.col("x:y") == 1).collect()
+        assert result.to_dicts() == [{"x:y": 1}]
+
+        df_space = pl.DataFrame({"x y": [1, 2]})
+        lf = duckdb_cursor.sql("from df_space").pl(lazy=True)
+        result = lf.select(pl.all()).collect()
+        assert result.to_dicts() == [{"x y": 1}, {"x y": 2}]
+        result = lf.select(pl.all()).filter(pl.col("x y") == 1).collect()
+        assert result.to_dicts() == [{"x y": 1}]
+
+        df_dot = pl.DataFrame({"x.y": [1, 2]})
+        lf = duckdb_cursor.sql("from df_dot").pl(lazy=True)
+        result = lf.select(pl.all()).collect()
+        assert result.to_dicts() == [{"x.y": 1}, {"x.y": 2}]
+        result = lf.select(pl.all()).filter(pl.col("x.y") == 1).collect()
+        assert result.to_dicts() == [{"x.y": 1}]
+
+        df_quote = pl.DataFrame({'"xy"': [1, 2]})
+        lf = duckdb_cursor.sql("from df_quote").pl(lazy=True)
+        result = lf.select(pl.all()).collect()
+        assert result.to_dicts() == [{'"xy"': 1}, {'"xy"': 2}]
+        result = lf.select(pl.all()).filter(pl.col('"xy"') == 1).collect()
+        assert result.to_dicts() == [{'"xy"': 1}]
+
     @pytest.mark.parametrize(
         'data_type',
         [
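
Illustrative usage (not part of the patch): a minimal sketch of what the identifier escaping enables end to end, assuming an in-memory DuckDB connection and the lazy Polars scan exposed via .pl(lazy=True); the DataFrame name df and its column name are hypothetical.

import duckdb
import polars as pl

df = pl.DataFrame({"x:y": [1, 2, 3]})    # column name containing a special character
con = duckdb.connect()
lazy = con.sql("from df").pl(lazy=True)  # DuckDB relation surfaced as a Polars LazyFrame

# The projection and filter are pushed down to DuckDB as SQL; without identifier
# quoting, x:y would not parse, while the escaping helper emits it as "x:y".
out = lazy.filter(pl.col("x:y") >= 2).select("x:y").collect()
print(out.to_dicts())                    # [{'x:y': 2}, {'x:y': 3}]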