diff --git a/CHANGELOG.md b/CHANGELOG.md index 026512a..c4c992f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,10 @@ ## 0.16 +### 0.16.3 + +- fix: Ensure existing functions are normalized in all cases. + ### 0.16.2 - fix: Cast pg_proc char columns to Text explicitly diff --git a/pyproject.toml b/pyproject.toml index 661802f..fa2966f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "sqlalchemy-declarative-extensions" -version = "0.16.2" +version = "0.16.3" authors = [ {name = "Dan Cardin", email = "ddcardin@gmail.com"}, ] diff --git a/src/sqlalchemy_declarative_extensions/dialects/postgresql/function.py b/src/sqlalchemy_declarative_extensions/dialects/postgresql/function.py index 2598b98..b45c1c8 100644 --- a/src/sqlalchemy_declarative_extensions/dialects/postgresql/function.py +++ b/src/sqlalchemy_declarative_extensions/dialects/postgresql/function.py @@ -255,12 +255,6 @@ def from_unknown( ) return cls(table=table_return_params) - # # Basic normalization: lowercase and remove extra spaces - # # This might need refinement for complex TABLE definitions - # inner_content = returns_lower[len("table(") : -1].strip() - # cols = [normalize_arg(c) for c in inner_content.split(",")] - # normalized_returns = f"table({', '.join(cols)})" - # return cls() # Normalize base return type (including array types) norm_type = type_map.get(returns_lower, returns_lower) diff --git a/src/sqlalchemy_declarative_extensions/function/compare.py b/src/sqlalchemy_declarative_extensions/function/compare.py index c5b6e06..d8eb472 100644 --- a/src/sqlalchemy_declarative_extensions/function/compare.py +++ b/src/sqlalchemy_declarative_extensions/function/compare.py @@ -56,7 +56,9 @@ def compare_functions(connection: Connection, functions: Functions) -> list[Oper raw_existing_functions = get_functions(connection) existing_functions = filter_functions(raw_existing_functions, functions.ignore) - existing_functions_by_name = {r.qualified_name: r for r in existing_functions} + existing_functions_by_name = { + f.qualified_name: f.normalize() for f in existing_functions + } existing_function_names = set(existing_functions_by_name) new_function_names = expected_function_names - existing_function_names @@ -75,8 +77,7 @@ def compare_functions(connection: Connection, functions: Functions) -> list[Oper result.append(CreateFunctionOp(normalized_function)) else: existing_function = existing_functions_by_name[function_name] - - normalized_existing_function = existing_function.normalize() + normalized_existing_function = existing_function if normalized_existing_function != normalized_function: result.append( diff --git a/tests/function/test_function_normalization.py b/tests/function/test_function_normalization.py new file mode 100644 index 0000000..afca50a --- /dev/null +++ b/tests/function/test_function_normalization.py @@ -0,0 +1,55 @@ +from pytest_mock_resources import create_postgres_fixture +from sqlalchemy import text + +from sqlalchemy_declarative_extensions import ( + declarative_database, + register_sqlalchemy_events, +) +from sqlalchemy_declarative_extensions.function.compare import ( + DropFunctionOp, + compare_functions, +) +from sqlalchemy_declarative_extensions.sqlalchemy import declarative_base + +_Base = declarative_base() + + +@declarative_database +class Base(_Base): # type: ignore + __abstract__ = True + + functions: list = [] + + +register_sqlalchemy_events(Base.metadata, functions=True) + +pg = create_postgres_fixture(engine_kwargs={"echo": True}, session=True) + + +def test_existing_function_normalized(pg): + Base.metadata.create_all(bind=pg.connection()) + pg.commit() + + pg.execute( + text( + """ + CREATE OR REPLACE FUNCTION echo_any_element( + input_value ANYELEMENT + ) + RETURNS ANYELEMENT + LANGUAGE sql + AS $$ + -- A generic function using the ANYELEMENT polymorphic type + SELECT input_value; + $$; + """ + ) + ) + + connection = pg.connection() + diff = compare_functions(connection, Base.metadata.info["functions"]) + + assert len(diff) == 1 + assert isinstance(diff[0], DropFunctionOp) + assert "DROP FUNCTION" in diff[0].to_sql()[0] + assert "CREATE FUNCTION" in diff[0].reverse().to_sql()[0] diff --git a/tests/schema/test_snowflake.py b/tests/schema/test_snowflake.py index 85264ce..db80fc0 100644 --- a/tests/schema/test_snowflake.py +++ b/tests/schema/test_snowflake.py @@ -1,10 +1,9 @@ -from sqlalchemy import Column, text, types +from sqlalchemy import text from sqlalchemy.engine import Engine, create_engine from sqlalchemy_declarative_extensions import ( declarative_database, register_sqlalchemy_events, - view, ) from sqlalchemy_declarative_extensions.sqlalchemy import declarative_base @@ -31,7 +30,7 @@ def test_create_schemas_filtered_to_database(snowflake: Engine): """ Base.metadata.create_all(bind=snowflake) - engine = create_engine('snowflake://user:password@account/db/schema') + engine = create_engine("snowflake://user:password@account/db/schema") with engine.connect() as conn: Base.metadata.create_all(engine) diff --git a/tests/trigger/test_drop_postgres.py b/tests/trigger/test_drop_postgres.py index 8d2d1bc..0881ea2 100644 --- a/tests/trigger/test_drop_postgres.py +++ b/tests/trigger/test_drop_postgres.py @@ -26,7 +26,7 @@ class Foo(Base): class TableWithSpecialName(Base): - __tablename__ = "user" # This name will trip up unquoted identifiers + __tablename__ = "user" # This name will trip up unquoted identifiers id = Column(types.Integer(), primary_key=True) diff --git a/tests/view/test_escaped_bindparam.py b/tests/view/test_escaped_bindparam.py index cf34f2d..9eed57e 100644 --- a/tests/view/test_escaped_bindparam.py +++ b/tests/view/test_escaped_bindparam.py @@ -6,11 +6,8 @@ register_sqlalchemy_events, view, ) -from sqlalchemy_declarative_extensions.alembic.view import UpdateViewOp -from sqlalchemy_declarative_extensions.dialects import postgresql from sqlalchemy_declarative_extensions.dialects.postgresql import View from sqlalchemy_declarative_extensions.sqlalchemy import declarative_base -from sqlalchemy_declarative_extensions.view.compare import compare_views _Base = declarative_base() @@ -48,6 +45,12 @@ def test_escape_bindparam_postgres(pg): # Make sure that bindparams escaping doesn't create unnecessary escapes # for the literal casts that appear after view definition round-tripping - rendered = View("simple_select", "SELECT 'a' as col1").render_definition(pg.connection()) - assert "::" in rendered, "Literals in the view definition are expected to get explicit type casts" - assert "\\:\\:" not in rendered, "Bind parameters escaping should leave type casts unescaped" + rendered = View("simple_select", "SELECT 'a' as col1").render_definition( + pg.connection() + ) + assert ( + "::" in rendered + ), "Literals in the view definition are expected to get explicit type casts" + assert ( + "\\:\\:" not in rendered + ), "Bind parameters escaping should leave type casts unescaped" diff --git a/uv.lock b/uv.lock index 96df286..93d097f 100644 --- a/uv.lock +++ b/uv.lock @@ -1415,7 +1415,7 @@ mypy = [ [[package]] name = "sqlalchemy-declarative-extensions" -version = "0.15.13" +version = "0.16.3" source = { editable = "." } dependencies = [ { name = "sqlalchemy" },