Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

### 0.17.0

- feat: Add support for Postgres sqlbody functions.

## 0.16

### 0.16.4
Expand Down
12 changes: 7 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "sqlalchemy-declarative-extensions"
version = "0.16.4"
version = "0.16.5"
authors = [
{name = "Dan Cardin", email = "[email protected]"},
]
Expand Down Expand Up @@ -45,9 +45,8 @@ changelog = "https://github.com/DanCardin/sqlalchemy-declarative-extensions/blob
alembic = ["alembic >= 1.0"]
parse = ["sqlglot"]

[tool.uv]
environments = ["python_version < '3.9'", "python_version >= '3.9' and python_version < '4'"]
dev-dependencies = [
[dependency-groups]
dev = [
"alembic-utils >= 0.8.1",
"coverage >= 5",
"mypy == 1.8.0",
Expand All @@ -58,7 +57,7 @@ dev-dependencies = [
"pytest-xdist",
"ruff >= 0.5.0",
"sqlalchemy[mypy] >= 1.4",
"psycopg",
"psycopg[binary]",
"psycopg2-binary",

# snowflake
Expand All @@ -67,6 +66,9 @@ dev-dependencies = [
"snowflake-sqlalchemy >= 1.6.0; python_version >= '3.9'",
]

[tool.uv]
environments = ["python_version < '3.9'", "python_version >= '3.9' and python_version < '4'"]

[tool.mypy]
strict_optional = true
ignore_missing_imports = true
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import enum
import re
import textwrap
from dataclasses import dataclass, replace
from typing import Any, List, Literal, Sequence, Tuple, cast
Expand All @@ -10,6 +11,22 @@
from sqlalchemy_declarative_extensions.function import base
from sqlalchemy_declarative_extensions.sql import quote_name

_sqlbody_regex = re.compile(r"\W*(BEGIN ATOMIC|RETURN)\W", re.IGNORECASE | re.MULTILINE)
"""sql_body
The body of a LANGUAGE SQL function. This can either be a single statement

RETURN expression

or a block

BEGIN ATOMIC
statement;
statement;
...
statement;
END
"""


@enum.unique
class FunctionSecurity(enum.Enum):
Expand Down Expand Up @@ -47,6 +64,12 @@ class Function(base.Function):
parameters: Sequence[FunctionParam | str] | None = None # type: ignore
volatility: FunctionVolatility = FunctionVolatility.VOLATILE

@property
def _has_sqlbody(self) -> bool:
return self.language.lower() == "sql" and bool(
_sqlbody_regex.match(self.definition)
)

def to_sql_create(self, replace=False) -> list[str]:
components = ["CREATE"]

Expand All @@ -72,7 +95,10 @@ def to_sql_create(self, replace=False) -> list[str]:
components.append(self.volatility.value)

components.append(f"LANGUAGE {self.language}")
components.append(f"AS $${self.definition}$$")
if self._has_sqlbody:
components.append(self.definition)
else:
components.append(f"AS $${self.definition}$$")

return [" ".join(components) + ";"]

Expand All @@ -96,6 +122,8 @@ def with_security_definer(self):

def normalize(self) -> Function:
definition = textwrap.dedent(self.definition)
if self._has_sqlbody:
definition = definition.strip()

# Normalize parameter types
parameters = []
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import enum
import re
import textwrap
from dataclasses import dataclass, replace

Expand All @@ -9,6 +10,18 @@
from sqlalchemy_declarative_extensions.procedure import base
from sqlalchemy_declarative_extensions.sql import quote_name

_sqlbody_regex = re.compile(r"\W?(BEGIN ATOMIC)\W", re.IGNORECASE | re.MULTILINE)
"""sql_body
The body of a LANGUAGE SQL procedure. This should be a block

BEGIN ATOMIC
statement;
statement;
...
statement;
END
"""


@enum.unique
class ProcedureSecurity(enum.Enum):
Expand All @@ -27,6 +40,12 @@ class Procedure(base.Procedure):

security: ProcedureSecurity = ProcedureSecurity.invoker

@property
def _has_sqlbody(self) -> bool:
return self.language.lower() == "sql" and bool(
_sqlbody_regex.match(self.definition)
)

def to_sql_create(self, replace=False) -> list[str]:
components = ["CREATE"]

Expand All @@ -40,7 +59,10 @@ def to_sql_create(self, replace=False) -> list[str]:
components.append("SECURITY DEFINER")

components.append(f"LANGUAGE {self.language}")
components.append(f"AS $${self.definition}$$")
if self._has_sqlbody:
components.append(self.definition)
else:
components.append(f"AS $${self.definition}$$")

return [" ".join(components) + ";"]

Expand All @@ -49,6 +71,8 @@ def to_sql_update(self) -> list[str]:

def normalize(self) -> Self:
definition = textwrap.dedent(self.definition)
if self._has_sqlbody:
definition = definition.strip()
return replace(self, definition=definition)

def with_security(self, security: ProcedureSecurity):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@
databases_query,
default_acl_query,
extensions_query,
functions_query,
get_functions_query,
get_procedures_query,
object_acl_query,
objects_query,
procedures_query,
roles_query,
schema_exists_query,
schemas_query,
Expand Down Expand Up @@ -195,6 +195,7 @@ def get_view_postgresql(connection: Connection, name: str, schema: str = "public

def get_procedures_postgresql(connection: Connection) -> Sequence[BaseProcedure]:
procedures = []
procedures_query = get_procedures_query(connection.dialect.server_version_info)
for f in connection.execute(procedures_query).fetchall():
name = f.name
definition = f.source
Expand Down Expand Up @@ -225,6 +226,7 @@ def get_functions_postgresql(connection: Connection) -> Sequence[BaseFunction]:
)

functions = []
functions_query = get_functions_query(connection.dialect.server_version_info)

for f in connection.execute(functions_query).fetchall():
name = f.name
Expand Down
138 changes: 82 additions & 56 deletions src/sqlalchemy_declarative_extensions/dialects/postgresql/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
union,
)
from sqlalchemy.dialects.postgresql import ARRAY, CHAR, REGCLASS, aggregate_order_by
from sqlalchemy.sql.functions import coalesce

from sqlalchemy_declarative_extensions.sqlalchemy import select

Expand Down Expand Up @@ -317,65 +318,90 @@ def get_types(arg_type_oids):
)


procedures_query = (
select(
pg_proc.c.proname.label("name"),
pg_namespace.c.nspname.label("schema"),
pg_language.c.lanname.label("language"),
pg_type.c.typname.label("return_type"),
pg_proc.c.prosrc.label("source"),
pg_proc.c.prosecdef.label("security_definer"),
pg_proc.c.prokind.label("kind"),
pg_proc.c.proargnames.label("arg_names"),
pg_proc.c.proargmodes.label("arg_modes"),
func.coalesce(
get_types(pg_proc.c.proallargtypes), get_types(pg_proc.c.proargtypes)
).label("arg_types"),
func.pg_get_expr(
pg_proc.c.proargdefaults, func.cast(literal("pg_proc"), REGCLASS)
).label("arg_defaults"),
)
.select_from(
pg_proc.join(pg_namespace, pg_proc.c.pronamespace == pg_namespace.c.oid)
.join(pg_language, pg_proc.c.prolang == pg_language.c.oid)
.join(pg_type, pg_proc.c.prorettype == pg_type.c.oid)
def get_procedures_query(version_info):
source = get_source_column(version_info)
return (
select(
pg_proc.c.proname.label("name"),
pg_namespace.c.nspname.label("schema"),
pg_language.c.lanname.label("language"),
pg_type.c.typname.label("return_type"),
source.label("source"),
pg_proc.c.prosecdef.label("security_definer"),
pg_proc.c.prokind.label("kind"),
pg_proc.c.proargnames.label("arg_names"),
pg_proc.c.proargmodes.label("arg_modes"),
func.coalesce(
get_types(pg_proc.c.proallargtypes), get_types(pg_proc.c.proargtypes)
).label("arg_types"),
func.pg_get_expr(
pg_proc.c.proargdefaults, func.cast(literal("pg_proc"), REGCLASS)
).label("arg_defaults"),
)
.select_from(
pg_proc.join(pg_namespace, pg_proc.c.pronamespace == pg_namespace.c.oid)
.join(pg_language, pg_proc.c.prolang == pg_language.c.oid)
.join(pg_type, pg_proc.c.prorettype == pg_type.c.oid)
)
.where(pg_namespace.c.nspname.notin_(["pg_catalog", "information_schema"]))
.where(pg_proc.c.prokind == "p")
.where(_schema_not_from_extension())
.where(_not_from_extension(pg_proc.c.oid, "pg_proc"))
)
.where(pg_namespace.c.nspname.notin_(["pg_catalog", "information_schema"]))
.where(pg_proc.c.prokind == "p")
.where(_schema_not_from_extension())
.where(_not_from_extension(pg_proc.c.oid, "pg_proc"))
)

functions_query = (
select(
pg_proc.c.proname.label("name"),
pg_namespace.c.nspname.label("schema"),
pg_language.c.lanname.label("language"),
pg_type.c.typname.label("base_return_type"),
pg_proc.c.prosrc.label("source"),
pg_proc.c.prosecdef.label("security_definer"),
cast(pg_proc.c.prokind, Text).label("kind"),
func.pg_get_function_arguments(pg_proc.c.oid).label("parameters"),
cast(pg_proc.c.provolatile, Text).label("volatility"),
func.pg_get_function_result(pg_proc.c.oid).label("return_type_string"),
pg_proc.c.proargnames.label("arg_names"),
pg_proc.c.proargmodes.label("arg_modes"),
func.coalesce(
get_types(pg_proc.c.proallargtypes), get_types(pg_proc.c.proargtypes)
).label("arg_types"),
func.pg_get_expr(
pg_proc.c.proargdefaults, func.cast(literal("pg_proc"), REGCLASS)
).label("arg_defaults"),
)
.select_from(
pg_proc.join(pg_namespace, pg_proc.c.pronamespace == pg_namespace.c.oid)
.join(pg_language, pg_proc.c.prolang == pg_language.c.oid)
.join(pg_type, pg_proc.c.prorettype == pg_type.c.oid)

def get_functions_query(version_info):
source = get_source_column(version_info)
return (
select(
pg_proc.c.proname.label("name"),
pg_namespace.c.nspname.label("schema"),
pg_language.c.lanname.label("language"),
pg_type.c.typname.label("base_return_type"),
source.label("source"),
pg_proc.c.prosecdef.label("security_definer"),
cast(pg_proc.c.prokind, Text).label("kind"),
func.pg_get_function_arguments(pg_proc.c.oid).label("parameters"),
cast(pg_proc.c.provolatile, Text).label("volatility"),
func.pg_get_function_result(pg_proc.c.oid).label("return_type_string"),
pg_proc.c.proargnames.label("arg_names"),
pg_proc.c.proargmodes.label("arg_modes"),
func.coalesce(
get_types(pg_proc.c.proallargtypes), get_types(pg_proc.c.proargtypes)
).label("arg_types"),
func.pg_get_expr(
pg_proc.c.proargdefaults, func.cast(literal("pg_proc"), REGCLASS)
).label("arg_defaults"),
)
.select_from(
pg_proc.join(pg_namespace, pg_proc.c.pronamespace == pg_namespace.c.oid)
.join(pg_language, pg_proc.c.prolang == pg_language.c.oid)
.join(pg_type, pg_proc.c.prorettype == pg_type.c.oid)
)
.where(pg_namespace.c.nspname.notin_(["pg_catalog", "information_schema"]))
.where(pg_proc.c.prokind != "p")
.where(_not_from_extension(pg_proc.c.oid, "pg_proc"))
)
.where(pg_namespace.c.nspname.notin_(["pg_catalog", "information_schema"]))
.where(pg_proc.c.prokind != "p")
.where(_not_from_extension(pg_proc.c.oid, "pg_proc"))
)


def get_source_column(version_info):
"""Postgres 14 introduced SQL-standard function and procedure bodies.

When writing a function or procedure in SQL-standard syntax, the body is parsed
immediately and stored as a parse tree. This allows better tracking of function
dependencies, and can have security benefits.

For these sqlbody functions, the pg_proc.prosrc column is an empty string.
The pre-parsed SQL function body is stored in pg_proc.prosqlbody as pg_node_tree.
The text representation can be returned along with the full ddl with
pg_get_functiondef.
Alternatively pg_get_function_sqlbody(pg_proc.oid) can be called to just get the
body. This function is not documented, see source:
https://doxygen.postgresql.org/ruleutils_8c.html#a99a3f975518b6b1707a3159c5f80427e
"""
if version_info >= (14, 0):
return coalesce(func.pg_get_function_sqlbody(pg_proc.c.oid), pg_proc.c.prosrc)
return pg_proc.c.prosrc


rel_nsp = pg_namespace.alias("rel_nsp")
Expand Down
34 changes: 34 additions & 0 deletions tests/dialect/postgresql/test_postgres_14_sqlbody.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import pytest
from pytest_mock_resources import PostgresConfig, create_postgres_fixture
from sqlalchemy import text

from sqlalchemy_declarative_extensions import Functions
from sqlalchemy_declarative_extensions.dialects.postgresql import (
Function,
FunctionVolatility,
)
from sqlalchemy_declarative_extensions.function.compare import compare_functions

pg = create_postgres_fixture(scope="function", engine_kwargs={"echo": True})


@pytest.fixture
def pmr_postgres_config():
return PostgresConfig(image="postgres:14", port=None, ci_port=None)


def test_functions(pg):
add_stable_function = Function(
name="add_stable",
definition="RETURN (i + 1)",
parameters=["i integer"],
returns="INTEGER",
volatility=FunctionVolatility.STABLE,
).normalize()
create_function = add_stable_function.to_sql_create()
functions = Functions([add_stable_function])
with pg.connect() as connection:
connection.execute(text("\n".join(create_function)))
diff = compare_functions(connection, functions)
for op in diff:
assert op.from_function == op.function
Loading
Loading