diff --git a/src/sqlalchemy_declarative_extensions/sqlalchemy.py b/src/sqlalchemy_declarative_extensions/sqlalchemy.py index b766710..d3ecf68 100644 --- a/src/sqlalchemy_declarative_extensions/sqlalchemy.py +++ b/src/sqlalchemy_declarative_extensions/sqlalchemy.py @@ -1,5 +1,6 @@ from __future__ import annotations +import re from typing import Callable, TypeVar import sqlalchemy @@ -50,8 +51,12 @@ def dispatch(connection: Connection, *args: P.args, **kwargs: P.kwargs) -> T: return dispatch +# https://github.com/sqlalchemy/sqlalchemy/blob/2e9902a34fafff0ac6d6c521a86c7dea3d96a392/lib/sqlalchemy/sql/elements.py#L2334 +_sqlalchemy_bind_params_regex = re.compile(r"(? str: - return query.replace(":", r"\:") + return _sqlalchemy_bind_params_regex.sub(r"\\:\1", query) if version.startswith("1.3"): diff --git a/tests/view/test_escaped_bindparam.py b/tests/view/test_escaped_bindparam.py index 1f5724d..cf34f2d 100644 --- a/tests/view/test_escaped_bindparam.py +++ b/tests/view/test_escaped_bindparam.py @@ -6,7 +6,11 @@ register_sqlalchemy_events, view, ) +from sqlalchemy_declarative_extensions.alembic.view import UpdateViewOp +from sqlalchemy_declarative_extensions.dialects import postgresql +from sqlalchemy_declarative_extensions.dialects.postgresql import View from sqlalchemy_declarative_extensions.sqlalchemy import declarative_base +from sqlalchemy_declarative_extensions.view.compare import compare_views _Base = declarative_base() @@ -41,3 +45,9 @@ def test_escape_bindparam_postgres(pg): result = pg.execute(text("select * from bar")).fetchall() assert result == [] + + # Make sure that bindparams escaping doesn't create unnecessary escapes + # for the literal casts that appear after view definition round-tripping + rendered = View("simple_select", "SELECT 'a' as col1").render_definition(pg.connection()) + assert "::" in rendered, "Literals in the view definition are expected to get explicit type casts" + assert "\\:\\:" not in rendered, "Bind parameters escaping should leave type casts unescaped"