Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

## 0.16

### 0.16.4

- fix: Handle unnamed parameters during normalization and CREATE statements.

### 0.16.3

- fix: Ensure existing functions are normalized in all cases.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "sqlalchemy-declarative-extensions"
version = "0.16.3"
version = "0.16.4"
authors = [
{name = "Dan Cardin", email = "[email protected]"},
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,28 +35,6 @@ def from_provolatile(cls, provolatile: str) -> FunctionVolatility:
raise ValueError(f"Invalid volatility: {provolatile}")


# def normalize_arg(arg: str) -> str:
# parts = arg.strip().split(maxsplit=1)
# if len(parts) == 2:
# name, type_str = parts
# norm_type = type_map.get(type_str.lower(), type_str.lower())
# # Handle array types
# if norm_type.endswith("[]"):
# base_type = norm_type[:-2]
# norm_base_type = type_map.get(base_type, base_type)
# norm_type = f"{norm_base_type}[]"
#
# return f"{name} {norm_type}"
# # Handle case where it might just be the type (e.g., from DROP FUNCTION)
# type_str = arg.strip()
# norm_type = type_map.get(type_str.lower(), type_str.lower())
# if norm_type.endswith("[]"):
# base_type = norm_type[:-2]
# norm_base_type = type_map.get(base_type, base_type)
# norm_type = f"{norm_base_type}[]"
# return norm_type


@dataclass
class Function(base.Function):
"""Describes a PostgreSQL function.
Expand Down Expand Up @@ -143,7 +121,7 @@ def normalize(self) -> Function:

@dataclass
class FunctionParam:
name: str
name: str | None
type: str
default: Any | None = None
mode: Literal["i", "o", "b", "v", "t"] | None = None
Expand Down Expand Up @@ -185,31 +163,40 @@ def from_unknown(
if isinstance(source_param, tuple):
return cls(*source_param)

name, type = source_param.strip().split(maxsplit=1)
try:
name, type = source_param.strip().split(maxsplit=1)
except ValueError:
name = None
type = source_param.strip()

return cls(name, type)

def normalize(self) -> FunctionParam:
type = self.type.lower()
return replace(
self,
name=self.name.lower(),
name=self.name.lower() if self.name is not None else None,
mode=self.mode or "i",
type=type_map.get(type, type),
default=str(self.default) if self.default is not None else None,
)

def to_sql_create(self) -> str:
result = ""
segments = []
if self.mode:
result += {"o": "OUT ", "b": "INOUT ", "v": "VARIADIC ", "t": "TABLE "}.get(
self.mode, ""
)
modes = {"o": "OUT ", "b": "INOUT ", "v": "VARIADIC ", "t": "TABLE "}
mode = modes.get(self.mode)
if mode:
segments.append(mode)

if self.name:
segments.append(self.name)

result += f"{self.name} {self.type}"
segments.append(self.type)

if self.default is not None:
result += f" DEFAULT {self.default}"
return result
segments.append(f"DEFAULT {self.default}")
return " ".join(segments)

def to_sql_drop(self) -> str:
return self.type
Expand Down Expand Up @@ -245,7 +232,7 @@ def from_unknown(
returns_lower = source.lower().strip()
if returns_lower.startswith("table("):
table_return_params = [
(p.name, p.type) for p in (parameters or []) if p.mode == "t"
(p.name, p.type) for p in (parameters or []) if p.name and p.mode == "t"
]

if not table_return_params:
Expand Down
64 changes: 64 additions & 0 deletions tests/function/test_unnamed_param.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from pytest_mock_resources import create_postgres_fixture
from sqlalchemy import text

from sqlalchemy_declarative_extensions.dialects.postgresql.function import Function
from sqlalchemy_declarative_extensions.function.base import Functions
from sqlalchemy_declarative_extensions.function.compare import (
DropFunctionOp,
compare_functions,
)

pg = create_postgres_fixture(engine_kwargs={"echo": True}, session=True)


def test_existing_function(pg):
pg.execute(
text(
"""
CREATE OR REPLACE FUNCTION add_numbers(integer, integer)
RETURNS integer AS $$
BEGIN
RETURN $1 + $2;
END;
$$ LANGUAGE plpgsql;
"""
)
)

connection = pg.connection()
diff = compare_functions(connection, Functions())

assert len(diff) == 1
assert isinstance(diff[0], DropFunctionOp)
assert "DROP FUNCTION" in diff[0].to_sql()[0]
assert "CREATE FUNCTION" in diff[0].reverse().to_sql()[0]


def test_creates_function(pg):
connection = pg.connection()
diff = compare_functions(
connection,
Functions(
functions=[
Function(
"add_numbers",
"""
BEGIN
RETURN $1 + $2;
END;
""",
language="plpgsql",
returns="integer",
parameters=["integer", "integer"],
)
]
),
)

assert len(diff) == 1
assert "CREATE FUNCTION" in diff[0].to_sql()[0]
pg.execute(text(diff[0].to_sql()[0]))
pg.commit()

result = pg.execute(text("SELECT add_numbers(2, 3)")).one()[0]
assert result == 5
Loading
Loading