Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

## 0.16

### 0.16.3

- fix: Ensure existing functions are normalized in all cases.

### 0.16.2

- fix: Cast pg_proc char columns to Text explicitly
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "sqlalchemy-declarative-extensions"
version = "0.16.2"
version = "0.16.3"
authors = [
{name = "Dan Cardin", email = "[email protected]"},
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,12 +255,6 @@ def from_unknown(
)

return cls(table=table_return_params)
# # Basic normalization: lowercase and remove extra spaces
# # This might need refinement for complex TABLE definitions
# inner_content = returns_lower[len("table(") : -1].strip()
# cols = [normalize_arg(c) for c in inner_content.split(",")]
# normalized_returns = f"table({', '.join(cols)})"
# return cls()

# Normalize base return type (including array types)
norm_type = type_map.get(returns_lower, returns_lower)
Expand Down
7 changes: 4 additions & 3 deletions src/sqlalchemy_declarative_extensions/function/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def compare_functions(connection: Connection, functions: Functions) -> list[Oper

raw_existing_functions = get_functions(connection)
existing_functions = filter_functions(raw_existing_functions, functions.ignore)
existing_functions_by_name = {r.qualified_name: r for r in existing_functions}
existing_functions_by_name = {
f.qualified_name: f.normalize() for f in existing_functions
}
existing_function_names = set(existing_functions_by_name)

new_function_names = expected_function_names - existing_function_names
Expand All @@ -75,8 +77,7 @@ def compare_functions(connection: Connection, functions: Functions) -> list[Oper
result.append(CreateFunctionOp(normalized_function))
else:
existing_function = existing_functions_by_name[function_name]

normalized_existing_function = existing_function.normalize()
normalized_existing_function = existing_function

if normalized_existing_function != normalized_function:
result.append(
Expand Down
55 changes: 55 additions & 0 deletions tests/function/test_function_normalization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from pytest_mock_resources import create_postgres_fixture
from sqlalchemy import text

from sqlalchemy_declarative_extensions import (
declarative_database,
register_sqlalchemy_events,
)
from sqlalchemy_declarative_extensions.function.compare import (
DropFunctionOp,
compare_functions,
)
from sqlalchemy_declarative_extensions.sqlalchemy import declarative_base

_Base = declarative_base()


@declarative_database
class Base(_Base): # type: ignore
__abstract__ = True

functions: list = []


register_sqlalchemy_events(Base.metadata, functions=True)

pg = create_postgres_fixture(engine_kwargs={"echo": True}, session=True)


def test_existing_function_normalized(pg):
Base.metadata.create_all(bind=pg.connection())
pg.commit()

pg.execute(
text(
"""
CREATE OR REPLACE FUNCTION echo_any_element(
input_value ANYELEMENT
)
RETURNS ANYELEMENT
LANGUAGE sql
AS $$
-- A generic function using the ANYELEMENT polymorphic type
SELECT input_value;
$$;
"""
)
)

connection = pg.connection()
diff = compare_functions(connection, Base.metadata.info["functions"])

assert len(diff) == 1
assert isinstance(diff[0], DropFunctionOp)
assert "DROP FUNCTION" in diff[0].to_sql()[0]
assert "CREATE FUNCTION" in diff[0].reverse().to_sql()[0]
5 changes: 2 additions & 3 deletions tests/schema/test_snowflake.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from sqlalchemy import Column, text, types
from sqlalchemy import text
from sqlalchemy.engine import Engine, create_engine

from sqlalchemy_declarative_extensions import (
declarative_database,
register_sqlalchemy_events,
view,
)
from sqlalchemy_declarative_extensions.sqlalchemy import declarative_base

Expand All @@ -31,7 +30,7 @@ def test_create_schemas_filtered_to_database(snowflake: Engine):
"""
Base.metadata.create_all(bind=snowflake)

engine = create_engine('snowflake://user:password@account/db/schema')
engine = create_engine("snowflake://user:password@account/db/schema")
with engine.connect() as conn:
Base.metadata.create_all(engine)

Expand Down
2 changes: 1 addition & 1 deletion tests/trigger/test_drop_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class Foo(Base):


class TableWithSpecialName(Base):
__tablename__ = "user" # This name will trip up unquoted identifiers
__tablename__ = "user" # This name will trip up unquoted identifiers

id = Column(types.Integer(), primary_key=True)

Expand Down
15 changes: 9 additions & 6 deletions tests/view/test_escaped_bindparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,8 @@
register_sqlalchemy_events,
view,
)
from sqlalchemy_declarative_extensions.alembic.view import UpdateViewOp
from sqlalchemy_declarative_extensions.dialects import postgresql
from sqlalchemy_declarative_extensions.dialects.postgresql import View
from sqlalchemy_declarative_extensions.sqlalchemy import declarative_base
from sqlalchemy_declarative_extensions.view.compare import compare_views

_Base = declarative_base()

Expand Down Expand Up @@ -48,6 +45,12 @@ def test_escape_bindparam_postgres(pg):

# Make sure that bindparams escaping doesn't create unnecessary escapes
# for the literal casts that appear after view definition round-tripping
rendered = View("simple_select", "SELECT 'a' as col1").render_definition(pg.connection())
assert "::" in rendered, "Literals in the view definition are expected to get explicit type casts"
assert "\\:\\:" not in rendered, "Bind parameters escaping should leave type casts unescaped"
rendered = View("simple_select", "SELECT 'a' as col1").render_definition(
pg.connection()
)
assert (
"::" in rendered
), "Literals in the view definition are expected to get explicit type casts"
assert (
"\\:\\:" not in rendered
), "Bind parameters escaping should leave type casts unescaped"
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading